magic_model.py 42.5 KB
Newer Older
icecraft's avatar
icecraft committed
1
import enum
liukaiwen's avatar
liukaiwen committed
2
import json
kernel.h@qq.com's avatar
kernel.h@qq.com committed
3

4
5
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
6
7
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
                                               FileBasedDataWriter)
8
from magic_pdf.data.dataset import Dataset
9
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
10
11
12
                                    bbox_relative_pos, box_area, calculate_iou,
                                    calculate_overlap_area_in_bbox1_area_ratio,
                                    get_overlap_area)
13
from magic_pdf.libs.commons import fitz, join_path
liukaiwen's avatar
liukaiwen committed
14
from magic_pdf.libs.coordinate_transform import get_scale_ratio
15
from magic_pdf.libs.local_math import float_gt
16
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
liukaiwen's avatar
liukaiwen committed
17

18
# An object box counts as lying inside a candidate caption region when more
# than this fraction of the object's own area overlaps that region.
CAPATION_OVERLAP_AREA_RATIO = 0.6
# Threshold on (overlap between another-category box and the merged
# subject+object bbox) / (object area); at or above it the subject/object pair
# is treated as separated and its distance is set to infinity.
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
liukaiwen's avatar
liukaiwen committed
20

许瑞's avatar
许瑞 committed
21

icecraft's avatar
icecraft committed
22
23
24
25
26
27
28
29
class PosRelationEnum(enum.Enum):
    """Relative position of one box with respect to another.

    Passed as ``priority_pos`` when pairing objects with subjects to say which
    side (for example ``BOTTOM`` or ``UP``) should win when an object has
    candidates on more than one side.
    """

    LEFT = 'left'
    RIGHT = 'right'
    UP = 'up'
    BOTTOM = 'bottom'
    ALL = 'all'


liukaiwen's avatar
liukaiwen committed
30
class MagicModel:
31
    """每个函数没有得到元素的时候返回空list."""
liukaiwen's avatar
liukaiwen committed
32
33
34

    def __fix_axis(self):
        """Give every layout detection a scaled, integer ``bbox``.

        For each page the raw coordinates are read from ``bbox`` when the
        detection has one (models such as paddle emit it directly). Otherwise
        they are read from the 8-value ``poly``, whose values 0-1 and 4-5 are
        used as the (x0, y0) and (x1, y1) corners. The coordinates are divided by
        the page's horizontal / vertical ratios from ``get_scale_ratio``,
        truncated with ``int`` and stored back in ``layout_det['bbox']``.
        Detections whose resulting width or height is <= 0 are removed from the
        page's ``layout_dets`` list in place.
        """
        for page_info in self.__model_list:
            page_no = page_info['page_info']['page_no']
            h_ratio, v_ratio = get_scale_ratio(
                page_info, self.__docs.get_page(page_no)
            )
            dets = page_info['layout_dets']
            degenerate = []

            for det in dets:
                raw_bbox = det.get('bbox')
                if raw_bbox is None:
                    # poly-style output: only the 1st and 3rd points are needed
                    x0, y0, _, _, x1, y1, _, _ = det['poly']
                else:
                    # bbox-style output (e.g. paddle)
                    x0, y0, x1, y1 = raw_bbox

                scaled = [
                    int(x0 / h_ratio),
                    int(y0 / v_ratio),
                    int(x1 / h_ratio),
                    int(y1 / v_ratio),
                ]
                det['bbox'] = scaled

                # empty boxes (non-positive width or height) are discarded
                if scaled[2] - scaled[0] <= 0 or scaled[3] - scaled[1] <= 0:
                    degenerate.append(det)

            for det in degenerate:
                dets.remove(det)

63
    def __fix_by_remove_low_confidence(self):
64
65
        for model_page_info in self.__model_list:
            need_remove_list = []
66
            layout_dets = model_page_info['layout_dets']
67
            for layout_det in layout_dets:
68
                if layout_det['score'] <= 0.05:
69
70
71
72
73
                    need_remove_list.append(layout_det)
                else:
                    continue
            for need_remove in need_remove_list:
                layout_dets.remove(need_remove)
liukaiwen's avatar
liukaiwen committed
74

75
76
77
    def __fix_by_remove_high_iou_and_low_confidence(self):
        """Drop the lower-scoring detection of every highly overlapping pair.

        For each page, every unordered pair of layout detections whose
        ``category_id`` is in 0-9 and whose bbox IoU is > 0.9 is examined, and
        the detection with the lower ``score`` is scheduled for removal. On a
        score tie the later detection in the list is removed, so one box of the
        pair always survives (exact duplicates are therefore de-duplicated).
        Removal decisions are made on the full, unfiltered page and then applied
        in place to ``layout_dets``.
        """
        candidate_categories = range(10)
        for page_info in self.__model_list:
            dets = page_info['layout_dets']
            drop = set()
            for i, det1 in enumerate(dets):
                if det1['category_id'] not in candidate_categories:
                    continue
                for j in range(i + 1, len(dets)):
                    det2 = dets[j]
                    if det2['category_id'] not in candidate_categories:
                        continue
                    if calculate_iou(det1['bbox'], det2['bbox']) > 0.9:
                        # Decide once per unordered pair. Visiting both orders
                        # (as before) removed det2 in one pass and det1 in the
                        # other on a score tie, deleting both boxes.
                        drop.add(i if det1['score'] < det2['score'] else j)
            if drop:
                # Remove by index so an equal-but-distinct dict is never
                # removed in place of the intended one.
                dets[:] = [det for k, det in enumerate(dets) if k not in drop]

113
    def __init__(self, model_list: list, docs: Dataset):
        """Store the per-page model output and dataset, then clean up the output.

        Args:
            model_list: per-page model results; the cleanup steps below mutate
                the pages' ``layout_dets`` lists in place.
            docs: dataset whose ``get_page(page_no)`` is used to compute the
                coordinate scale ratios.
        """
        self.__model_list = model_list
        self.__docs = docs
        # Add a scaled integer bbox to every detection (converting poly -> bbox).
        self.__fix_axis()
        # Drop detections with very low confidence (score <= 0.05).
        self.__fix_by_remove_low_confidence()
        # Among pairs with IoU > 0.9, drop the lower-confidence detection.
        self.__fix_by_remove_high_iou_and_low_confidence()
        # Relabel footnotes that are closer to a figure than to a table.
        self.__fix_footnote()

124
125
126
127
128
129
130
131
132
    def _bbox_distance(self, bbox1, bbox2):
        """Distance between two boxes as used for subject/object pairing.

        Returns ``float('inf')`` when:
          * ``bbox_relative_pos`` reports more than one of left/right/bottom/top
            (the boxes are offset along both axes), or
          * ``bbox2`` is more than 30% longer than ``bbox1`` along the edge they
            face each other with (heights when one is left/right of the other,
            widths otherwise), or ``bbox1`` has a non-positive length there and
            ``bbox2`` is longer.
        Otherwise returns ``bbox_distance(bbox1, bbox2)``.
        """
        left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
        if sum(map(bool, (left, right, bottom, top))) > 1:
            return float('inf')

        if left or right:
            l1 = bbox1[3] - bbox1[1]
            l2 = bbox2[3] - bbox2[1]
        else:
            l1 = bbox1[2] - bbox1[0]
            l2 = bbox2[2] - bbox2[0]

        # A degenerate bbox1 (l1 <= 0) previously raised ZeroDivisionError; any
        # excess over a non-positive length is treated as "too long".
        if l2 > l1 and (l1 <= 0 or (l2 - l1) / l1 > 0.3):
            return float('inf')

        return bbox_distance(bbox1, bbox2)

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
    def __fix_footnote(self):
        """Relabel footnotes that sit closer to a figure than to any table.

        Category ids used here: 3 figure, 5 table, 7 footnote. For each page
        with at least one footnote and one figure, every footnote's smallest
        ``_bbox_distance`` to a figure and to a table is computed. Pairs for
        which ``bbox_relative_pos`` reports more than one relation are skipped.
        A footnote whose figure distance is strictly smaller than its table
        distance (infinity when no table qualifies) gets
        ``category_id = CategoryId.ImageFootnote``.
        """

        def nearest_distance(footnote, candidates):
            """Smallest distance from any eligible candidate to footnote, or None."""
            best = None
            for candidate in candidates:
                relations = bbox_relative_pos(footnote['bbox'], candidate['bbox'])
                if sum(map(bool, relations)) > 1:
                    continue
                dis = self._bbox_distance(candidate['bbox'], footnote['bbox'])
                best = dis if best is None else min(dis, best)
            return best

        for page_info in self.__model_list:
            footnotes, figures, tables = [], [], []
            for obj in page_info['layout_dets']:
                category = obj['category_id']
                if category == 7:
                    footnotes.append(obj)
                elif category == 3:
                    figures.append(obj)
                elif category == 5:
                    tables.append(obj)

            # Nothing to relabel without both footnotes and figures. This check
            # used to sit inside the loop above, where it was a no-op.
            if not footnotes or not figures:
                continue

            for footnote in footnotes:
                figure_dis = nearest_distance(footnote, figures)
                if figure_dis is None:
                    continue
                table_dis = nearest_distance(footnote, tables)
                if table_dis is None:
                    table_dis = float('inf')
                if table_dis > figure_dis:
                    footnote['category_id'] = CategoryId.ImageFootnote
liukaiwen's avatar
liukaiwen committed
203
204
205
206
207
208
209
210

    def __reduct_overlap(self, bboxes):
        """Keep only the entries whose box is not contained in another entry's box.

        ``bboxes`` is a list of dicts with a ``'bbox'`` key. An entry is dropped
        when ``_is_in(entry['bbox'], other['bbox'])`` holds for any other entry
        (compared by position, so an entry is never tested against itself).
        Survivors are returned in a new list, in their original order.

        NOTE(review): if ``_is_in`` treats equal boxes as contained, two
        identical boxes eliminate each other and neither survives. Confirm
        against ``magic_pdf.libs.boxbase._is_in``.
        """

        def contained_elsewhere(idx, item):
            return any(
                _is_in(item['bbox'], other['bbox'])
                for other_idx, other in enumerate(bboxes)
                if other_idx != idx
            )

        return [
            item
            for idx, item in enumerate(bboxes)
            if not contained_elsewhere(idx, item)
        ]

    def __tie_up_category_by_distance(
        self, page_no, subject_category_id, object_category_id
    ):
        """Pair subjects (e.g. figures) with nearby objects (e.g. captions) on a page.

        Assumes each subject owns at most one object (several adjacent objects
        may be merged into a single one) and each object belongs to at most one
        subject.

        Args:
            page_no: index of the page in the model list.
            subject_category_id: ``category_id`` of the subject detections.
            object_category_id: ``category_id`` of the object detections.

        Returns:
            ``(ret, total_subject_object_dis)``.

            ``ret`` holds one dict per subject, ordered by ``x0**2 + y0**2`` of
            the subject bbox. Each dict has the keys ``subject_body``, ``all``
            and ``score``. When objects were attached it also has
            ``object_body`` (the union of the object bboxes), and ``all``
            becomes the union of subject and objects.

            ``total_subject_object_dis`` is the sum of ``bbox_distance`` over the
            matched pairs. It also adds, for each unmatched object, the pairing
            distance to its nearest subject that has no objects yet (an
            approximation).
        """
        ret = []
        MAX_DIS_OF_POINT = 10**9 + 7

        def expand_bbbox(idxes):
            """Return (x0, y0, x1, y1) enclosing all_bboxes[idx] for idx in idxes."""
            boxes = [all_bboxes[idx]['bbox'] for idx in idxes]
            return (
                min(b[0] for b in boxes),
                min(b[1] for b in boxes),
                max(b[2] for b in boxes),
                max(b[3] for b in boxes),
            )

        def search_overlap_between_boxes(subject_idx, object_idx):
            """Largest overlap of a third-category box with the subject+object union.

            The subject and object bboxes are merged into one enclosing bbox. For
            every detection on the page whose category is neither the subject's
            nor the object's, the overlap area with the merged bbox divided by
            the object's area is computed. The maximum is returned, and scanning
            stops early once it reaches MERGE_BOX_OVERLAP_AREA_RATIO.
            """
            merged_bbox = list(expand_bbbox([subject_idx, object_idx]))
            object_area = box_area(all_bboxes[object_idx]['bbox'])
            ratio = 0
            for det in self.__model_list[page_no]['layout_dets']:
                if det['category_id'] in (object_category_id, subject_category_id):
                    continue
                ratio = max(
                    ratio,
                    get_overlap_area(merged_bbox, det['bbox']) * 1.0 / object_area,
                )
                if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
                    break
            return ratio

        def may_find_other_nearest_bbox(subject_idx, object_idx):
            """Smallest distance from another, at-least-as-large subject to the object.

            Only subjects other than ``subject_idx`` are considered. Each must
            partially overlap, or lie inside, the union of the subject and object
            bboxes, and have an area >= the object's area. Returns
            ``float('inf')`` when no subject qualifies.
            """
            ret = float('inf')
            union = list(expand_bbbox([subject_idx, object_idx]))
            ox0, oy0, ox1, oy1 = all_bboxes[object_idx]['bbox']
            object_area = abs(ox1 - ox0) * abs(oy1 - oy0)
            for i in range(len(all_bboxes)):
                if (
                    i == subject_idx
                    or all_bboxes[i]['category_id'] != subject_category_id
                ):
                    continue
                bbox = all_bboxes[i]['bbox']
                if _is_part_overlap(union, bbox) or _is_in(bbox, union):
                    i_area = abs(bbox[2] - bbox[0]) * abs(bbox[3] - bbox[1])
                    if i_area >= object_area:
                        # Was min(float('inf'), ...), which kept only the last
                        # qualifying subject instead of the nearest one.
                        ret = min(ret, dis[i][object_idx])
            return ret

        layout_dets = self.__model_list[page_no]['layout_dets']
        subjects = self.__reduct_overlap(
            [
                {'bbox': x['bbox'], 'score': x['score']}
                for x in layout_dets
                if x['category_id'] == subject_category_id
            ]
        )
        objects = self.__reduct_overlap(
            [
                {'bbox': x['bbox'], 'score': x['score']}
                for x in layout_dets
                if x['category_id'] == object_category_id
            ]
        )
        subject_object_relation_map = {}

        # Order subjects by squared distance of their (x0, y0) corner from the origin.
        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

        # Subjects first, then objects; indices into this list are used below.
        all_bboxes = [
            {'category_id': subject_category_id, 'bbox': v['bbox'], 'score': v['score']}
            for v in subjects
        ] + [
            {'category_id': object_category_id, 'bbox': v['bbox'], 'score': v['score']}
            for v in objects
        ]

        N = len(all_bboxes)
        dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]

        for i in range(N):
            for j in range(i):
                if (
                    all_bboxes[i]['category_id'] == subject_category_id
                    and all_bboxes[j]['category_id'] == subject_category_id
                ):
                    continue

                subject_idx, object_idx = i, j
                if all_bboxes[j]['category_id'] == subject_category_id:
                    subject_idx, object_idx = j, i

                if (
                    search_overlap_between_boxes(subject_idx, object_idx)
                    >= MERGE_BOX_OVERLAP_AREA_RATIO
                ):
                    dis[i][j] = dis[j][i] = float('inf')
                    continue

                dis[i][j] = dis[j][i] = self._bbox_distance(
                    all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
                )

        used = set()
        for i in range(N):
            # Find the objects that belong to subject i.
            if all_bboxes[i]['category_id'] != subject_category_id:
                continue
            seen = set()
            candidates = []
            arr = []
            for j in range(N):
                relations = bbox_relative_pos(
                    all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
                )
                if sum(map(bool, relations)) > 1:
                    continue
                if (
                    all_bboxes[j]['category_id'] != object_category_id
                    or j in used
                    or dis[i][j] == MAX_DIS_OF_POINT
                ):
                    continue
                # At most one relation holds here (guaranteed by the check above).
                left, right, _, _ = relations
                if left or right:
                    one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
                else:
                    one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
                if dis[i][j] > one_way_dis:
                    continue
                arr.append((dis[i][j], j))

            arr.sort(key=lambda x: x[0])
            if len(arr) > 0:
                # Known limitation: the object nearest to this subject may lie
                # beyond another subject, e.g.
                # [this subject] [some other subject] [nearest object of this subject]
                if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
                    candidates.append(arr[0][1])
                    seen.add(arr[0][1])

            # The initial seed has been found; grow from it.
            # NOTE(review): iteration is over a snapshot of the seeds, so objects
            # added to ``tmp`` are not themselves expanded. Confirm this is intended.
            for j in set(candidates):
                tmp = []
                for k in range(i + 1, N):
                    relations = bbox_relative_pos(
                        all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
                    )
                    if sum(map(bool, relations)) > 1:
                        continue

                    if (
                        all_bboxes[k]['category_id'] != object_category_id
                        or k in used
                        or k in seen
                        or dis[j][k] == MAX_DIS_OF_POINT
                        or dis[j][k] > dis[i][j]
                    ):
                        continue

                    # k must be strictly closer to j than to any other free box.
                    is_nearest = True
                    for ni in range(i + 1, N):
                        if ni in (j, k) or ni in used or ni in seen:
                            continue
                        if not float_gt(dis[ni][k], dis[j][k]):
                            is_nearest = False
                            break

                    if is_nearest:
                        nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
                        n_dis = bbox_distance(
                            all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
                        )
                        if float_gt(dis[i][j], n_dis):
                            continue
                        tmp.append(k)
                        seen.add(k)

                candidates = tmp
                if len(candidates) == 0:
                    break

            # All objects nearest to this subject, and the objects nearest to
            # those, have been collected. Expand the bbox first.
            ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
            ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']

            # Four strips around the subject (left, top, bottom, right). For
            # each, compute the area of the rectangle enclosing the objects
            # that fall inside it.
            caption_poses = [
                [ox0, oy0, ix0, oy1],
                [ox0, oy0, ox1, iy0],
                [ox0, iy1, ox1, oy1],
                [ix1, oy0, ox1, oy1],
            ]

            caption_areas = []
            for bbox in caption_poses:
                embed_arr = [
                    idx
                    for idx in seen
                    if calculate_overlap_area_in_bbox1_area_ratio(
                        all_bboxes[idx]['bbox'], bbox
                    )
                    > CAPATION_OVERLAP_AREA_RATIO
                ]
                if embed_arr:
                    ex0, ey0, ex1, ey1 = expand_bbbox(embed_arr)
                    caption_areas.append(int(abs(ex1 - ex0) * abs(ey1 - ey0)))
                else:
                    caption_areas.append(0)

            subject_object_relation_map[i] = []
            if max(caption_areas) > 0:
                max_area_idx = caption_areas.index(max(caption_areas))
                caption_bbox = caption_poses[max_area_idx]

                for j in seen:
                    if (
                        calculate_overlap_area_in_bbox1_area_ratio(
                            all_bboxes[j]['bbox'], caption_bbox
                        )
                        > CAPATION_OVERLAP_AREA_RATIO
                    ):
                        used.add(j)
                        subject_object_relation_map[i].append(j)

        for i in sorted(subject_object_relation_map.keys()):
            result = {
                'subject_body': all_bboxes[i]['bbox'],
                'all': all_bboxes[i]['bbox'],
                'score': all_bboxes[i]['score'],
            }

            if len(subject_object_relation_map[i]) > 0:
                x0, y0, x1, y1 = expand_bbbox(subject_object_relation_map[i])
                result['object_body'] = [x0, y0, x1, y1]
                result['all'] = [
                    min(x0, all_bboxes[i]['bbox'][0]),
                    min(y0, all_bboxes[i]['bbox'][1]),
                    max(x1, all_bboxes[i]['bbox'][2]),
                    max(y1, all_bboxes[i]['bbox'][3]),
                ]
            ret.append(result)

        total_subject_object_dis = 0
        # Distance of the matched pairs.
        for i, object_idxes in subject_object_relation_map.items():
            for j in object_idxes:
                total_subject_object_dis += bbox_distance(
                    all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
                )

        # Distance between unmatched objects and subjects (approximate).
        # Subjects that already own objects are excluded. This previously
        # tested the leaked loop variable ``i`` instead of each ``key``.
        with_caption_subject = {
            key
            for key, object_idxes in subject_object_relation_map.items()
            if len(object_idxes) > 0
        }
        for i in range(N):
            if all_bboxes[i]['category_id'] != object_category_id or i in used:
                continue
            candidates = [
                (dis[i][j], j)
                for j in range(N)
                if all_bboxes[j]['category_id'] == subject_category_id
                and j not in with_caption_subject
            ]
            if candidates:
                # First minimum wins, matching the previous stable sort + [0].
                nearest_dis, nearest_subject = min(candidates, key=lambda x: x[0])
                # Previously added the subject *index* and marked the leaked
                # loop variable ``j`` (always N - 1) as claimed.
                total_subject_object_dis += nearest_dis
                with_caption_subject.add(nearest_subject)
        return ret, total_subject_object_dis

602
    def __tie_up_category_by_distance_v2(
icecraft's avatar
icecraft committed
603
604
605
606
607
        self,
        page_no: int,
        subject_category_id: int,
        object_category_id: int,
        priority_pos: PosRelationEnum,
608
    ):
icecraft's avatar
icecraft committed
609
        """_summary_
610

icecraft's avatar
icecraft committed
611
612
613
614
615
616
617
618
619
        Args:
            page_no (int): _description_
            subject_category_id (int): _description_
            object_category_id (int): _description_
            priority_pos (PosRelationEnum): _description_

        Returns:
            _type_: _description_
        """
icecraft's avatar
icecraft committed
620
        AXIS_MULPLICITY = 0.5
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
        subjects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == subject_category_id,
                        self.__model_list[page_no]['layout_dets'],
                    ),
                )
            )
        )

        objects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == object_category_id,
                        self.__model_list[page_no]['layout_dets'],
                    ),
                )
            )
        )
644
        M = len(objects)
645
646
647

        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
648
649
650
651
652
653
654
655
656
657

        sub_obj_map_h = {i: [] for i in range(len(subjects))}

        dis_by_directions = {
            'top': [[-1, float('inf')]] * M,
            'bottom': [[-1, float('inf')]] * M,
            'left': [[-1, float('inf')]] * M,
            'right': [[-1, float('inf')]] * M,
        }

658
        for i, obj in enumerate(objects):
659
660
661
662
663
            l_x_axis, l_y_axis = (
                obj['bbox'][2] - obj['bbox'][0],
                obj['bbox'][3] - obj['bbox'][1],
            )
            axis_unit = min(l_x_axis, l_y_axis)
664
665
            for j, sub in enumerate(subjects):

icecraft's avatar
icecraft committed
666
667
                bbox1, bbox2, _ = _remove_overlap_between_bbox(
                    objects[i]['bbox'], subjects[j]['bbox']
668
                )
icecraft's avatar
icecraft committed
669
                left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
670
671
                flags = [left, right, bottom, top]
                if sum([1 if v else 0 for v in flags]) > 1:
672
673
                    continue

674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
                if left:
                    if dis_by_directions['left'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['left'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]
                if right:
                    if dis_by_directions['right'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['right'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]
                if bottom:
                    if dis_by_directions['bottom'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['bottom'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]
                if top:
                    if dis_by_directions['top'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['top'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]
icecraft's avatar
icecraft committed
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726

            if (
                dis_by_directions['top'][i][1] != float('inf')
                and dis_by_directions['bottom'][i][1] != float('inf')
                and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
            ):
                RATIO = 3
                if (
                    abs(
                        dis_by_directions['top'][i][1]
                        - dis_by_directions['bottom'][i][1]
                    )
                    < RATIO * axis_unit
                ):

                    if priority_pos == PosRelationEnum.BOTTOM:
                        sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
                    else:
                        sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
                    continue

727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
            if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
                'right'
            ][i][1] != float('inf'):
                if dis_by_directions['left'][i][1] != float(
                    'inf'
                ) and dis_by_directions['right'][i][1] != float('inf'):
                    if AXIS_MULPLICITY * axis_unit >= abs(
                        dis_by_directions['left'][i][1]
                        - dis_by_directions['right'][i][1]
                    ):
                        left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
                            'bbox'
                        ]
                        right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
                            'bbox'
                        ]

                        left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
                        right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]

icecraft's avatar
icecraft committed
747
748
749
750
751
                        if (
                            abs(left_sub_bbox_y_axis - l_y_axis)
                            + dis_by_directions['left'][i][0]
                            > abs(right_sub_bbox_y_axis - l_y_axis)
                            + dis_by_directions['right'][i][0]
752
753
754
755
756
757
                        ):
                            left_or_right = dis_by_directions['right'][i]
                        else:
                            left_or_right = dis_by_directions['left'][i]
                    else:
                        left_or_right = dis_by_directions['left'][i]
icecraft's avatar
icecraft committed
758
                        if left_or_right[1] > dis_by_directions['right'][i][1]:
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
                            left_or_right = dis_by_directions['right'][i]
                else:
                    left_or_right = dis_by_directions['left'][i]
                    if left_or_right[1] == float('inf'):
                        left_or_right = dis_by_directions['right'][i]
            else:
                left_or_right = [-1, float('inf')]

            if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
                'bottom'
            ][i][1] != float('inf'):
                if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
                    'bottom'
                ][i][1] != float('inf'):
                    if AXIS_MULPLICITY * axis_unit >= abs(
                        dis_by_directions['top'][i][1]
                        - dis_by_directions['bottom'][i][1]
                    ):
                        top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
                        bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']

                        top_bottom_x_axis = top_bottom[2] - top_bottom[0]
                        bottom_top_x_axis = bottom_top[2] - bottom_top[0]
icecraft's avatar
icecraft committed
782
783
784
785
786
787
                        if (
                            abs(top_bottom_x_axis - l_x_axis)
                            + dis_by_directions['bottom'][i][1]
                            > abs(bottom_top_x_axis - l_x_axis)
                            + dis_by_directions['top'][i][1]
                        ):
788
                            top_or_bottom = dis_by_directions['top'][i]
icecraft's avatar
icecraft committed
789
790
                        else:
                            top_or_bottom = dis_by_directions['bottom'][i]
791
792
                    else:
                        top_or_bottom = dis_by_directions['top'][i]
icecraft's avatar
icecraft committed
793
                        if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
794
795
796
797
798
                            top_or_bottom = dis_by_directions['bottom'][i]
                else:
                    top_or_bottom = dis_by_directions['top'][i]
                    if top_or_bottom[1] == float('inf'):
                        top_or_bottom = dis_by_directions['bottom'][i]
799
            else:
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
                top_or_bottom = [-1, float('inf')]

            if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
                if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
                    'inf'
                ):
                    if AXIS_MULPLICITY * axis_unit >= abs(
                        left_or_right[1] - top_or_bottom[1]
                    ):
                        y_axis_bbox = subjects[left_or_right[0]]['bbox']
                        x_axis_bbox = subjects[top_or_bottom[0]]['bbox']

                        if (
                            abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
                            > abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
                            / l_y_axis
                        ):
                            sub_obj_map_h[left_or_right[0]].append(i)
                        else:
                            sub_obj_map_h[top_or_bottom[0]].append(i)
                    else:
                        if left_or_right[1] > top_or_bottom[1]:
                            sub_obj_map_h[top_or_bottom[0]].append(i)
                        else:
                            sub_obj_map_h[left_or_right[0]].append(i)
                else:
                    if left_or_right[1] != float('inf'):
                        sub_obj_map_h[left_or_right[0]].append(i)
                    else:
                        sub_obj_map_h[top_or_bottom[0]].append(i)
830
831
832
833
        ret = []
        for i in sub_obj_map_h.keys():
            ret.append(
                {
icecraft's avatar
icecraft committed
834
835
836
837
                    'sub_bbox': {
                        'bbox': subjects[i]['bbox'],
                        'score': subjects[i]['score'],
                    },
838
839
840
841
                    'obj_bboxes': [
                        {'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
                        for j in sub_obj_map_h[i]
                    ],
842
843
844
845
846
847
                    'sub_idx': i,
                }
            )
        return ret

    def get_imgs_v2(self, page_no: int):
        """Group every image body on a page with its captions and footnotes.

        Captions (category 4) are tied to image bodies (category 3) using
        ``PosRelationEnum.BOTTOM`` as the position preference. Footnotes
        (``CategoryId.ImageFootnote``) use ``PosRelationEnum.ALL``.

        Args:
            page_no: index of the page to read.

        Returns:
            One dict per image body, in the order returned by the caption
            grouping. Each dict has the keys ``'image_body'``,
            ``'image_caption_list'`` and ``'image_footnote_list'``.
        """
        caption_groups = self.__tie_up_category_by_distance_v2(
            page_no, 3, 4, PosRelationEnum.BOTTOM
        )
        footnote_groups = self.__tie_up_category_by_distance_v2(
            page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
        )

        def footnotes_for(sub_idx):
            # Both groupings index the same subject list, so the same
            # ``sub_idx`` identifies the same image body in each.
            match = next(g for g in footnote_groups if g['sub_idx'] == sub_idx)
            return match['obj_bboxes']

        return [
            {
                'image_body': group['sub_bbox'],
                'image_caption_list': group['obj_bboxes'],
                'image_footnote_list': footnotes_for(group['sub_idx']),
            }
            for group in caption_groups
        ]

    def get_tables_v2(self, page_no: int) -> list:
        """Group every table body on a page with its captions and footnotes.

        Captions (category 6) are tied to table bodies (category 5) using
        ``PosRelationEnum.UP`` as the position preference. Footnotes
        (category 7) use ``PosRelationEnum.ALL``.

        Args:
            page_no: index of the page to read.

        Returns:
            One dict per table body with the keys ``'table_body'``,
            ``'table_caption_list'`` and ``'table_footnote_list'``.
        """
        caption_groups = self.__tie_up_category_by_distance_v2(
            page_no, 5, 6, PosRelationEnum.UP
        )
        footnote_groups = self.__tie_up_category_by_distance_v2(
            page_no, 5, 7, PosRelationEnum.ALL
        )

        tables = []
        for group in caption_groups:
            sub_idx = group['sub_idx']
            # The footnote grouping covers the same subjects, keyed by sub_idx.
            footnote_group = next(
                fg for fg in footnote_groups if fg['sub_idx'] == sub_idx
            )
            tables.append(
                {
                    'table_body': group['sub_bbox'],
                    'table_caption_list': group['obj_bboxes'],
                    'table_footnote_list': footnote_group['obj_bboxes'],
                }
            )
        return tables

blue's avatar
blue committed
885
    def get_imgs(self, page_no: int):
        """Collect the image groups (body, caption, footnote) of one page.

        The caption grouping (category 3 -> 4) and the footnote grouping
        (category 3 -> ``CategoryId.ImageFootnote``) are expected to list the
        same image bodies in the same order. They are paired by position.

        Args:
            page_no: index of the page to read.

        Returns:
            One dict per image with ``'score'``, ``'img_caption_bbox'``,
            ``'img_body_bbox'``, ``'img_footnote_bbox'`` (``None`` when no
            caption/footnote was tied) and ``'bbox'``. ``'bbox'`` is the union
            of the two groupings' ``'all'`` boxes.
        """
        with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
        with_footnotes, _ = self.__tie_up_category_by_distance(
            page_no, 3, CategoryId.ImageFootnote
        )
        assert len(with_captions) == len(with_footnotes)

        images = []
        for idx, cap in enumerate(with_captions):
            foot = with_footnotes[idx]
            images.append(
                {
                    'score': cap['score'],
                    'img_caption_bbox': cap.get('object_body', None),
                    'img_body_bbox': cap['subject_body'],
                    'img_footnote_bbox': foot.get('object_body', None),
                    # Smallest box enclosing both groupings.
                    'bbox': [
                        min(cap['all'][0], foot['all'][0]),
                        min(cap['all'][1], foot['all'][1]),
                        max(cap['all'][2], foot['all'][2]),
                        max(cap['all'][3], foot['all'][3]),
                    ],
                }
            )
        return images
liukaiwen's avatar
liukaiwen committed
908
909

    def get_tables(self, page_no: int) -> list:
        """Collect the table groups (caption, body, footnote) of one page.

        The caption grouping (category 5 -> 6) and the footnote grouping
        (category 5 -> 7) are expected to list the same table bodies in the
        same order. They are paired by position.

        Args:
            page_no: index of the page to read.

        Returns:
            One dict per table with ``'score'``, ``'table_caption_bbox'``,
            ``'table_body_bbox'``, ``'table_footnote_bbox'`` (``None`` when no
            caption/footnote was tied) and ``'bbox'``. ``'bbox'`` is the union
            of the two groupings' ``'all'`` boxes.
        """
        with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
        with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
        assert len(with_captions) == len(with_footnotes)

        tables = []
        for idx, cap in enumerate(with_captions):
            foot = with_footnotes[idx]
            tables.append(
                {
                    'score': cap['score'],
                    'table_caption_bbox': cap.get('object_body', None),
                    'table_body_bbox': cap['subject_body'],
                    'table_footnote_bbox': foot.get('object_body', None),
                    # Smallest box enclosing both groupings.
                    'bbox': [
                        min(cap['all'][0], foot['all'][0]),
                        min(cap['all'][1], foot['all'][1]),
                        max(cap['all'][2], foot['all'][2]),
                        max(cap['all'][3], foot['all'][3]),
                    ],
                }
            )
        return tables

    def get_equations(self, page_no: int) -> tuple[list, list, list]:
        """Return the equation detections of one page (coordinates and LaTeX).

        Args:
            page_no: page number to look up.

        Returns:
            A 3-tuple ``(inline_equations, interline_equations,
            interline_equations_blocks)``:

            * ``inline_equations``: ``EMBEDDING`` blocks, each with its
              ``'latex'`` value.
            * ``interline_equations``: ``ISOLATED`` blocks, each with its
              ``'latex'`` value.
            * ``interline_equations_blocks``: ``ISOLATE_FORMULA`` blocks,
              carrying only ``'bbox'`` and ``'score'``.
        """
        inline_equations = self.__get_blocks_by_type(
            ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
        )
        interline_equations = self.__get_blocks_by_type(
            ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
        )
        interline_equations_blocks = self.__get_blocks_by_type(
            ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
        )
        return inline_equations, interline_equations, interline_equations_blocks

    def get_discarded(self, page_no: int) -> list:
        """Return the ``ABANDON``-type blocks of a page.

        These come from the in-house layout model. Each block carries only
        ``'bbox'`` and ``'score'`` (no text).
        """
        abandon_type = ModelBlockTypeEnum.ABANDON.value
        return self.__get_blocks_by_type(abandon_type, page_no)

    def get_text_blocks(self, page_no: int) -> list:
        """Return the ``PLAIN_TEXT``-type blocks of a page.

        These come from the in-house layout model. Each block carries only
        ``'bbox'`` and ``'score'``; the text itself is not included.
        """
        plain_text_type = ModelBlockTypeEnum.PLAIN_TEXT.value
        return self.__get_blocks_by_type(plain_text_type, page_no)

    def get_title_blocks(self, page_no: int) -> list:
        """Return the ``TITLE``-type blocks of a page.

        These come from the in-house layout model. Each block carries only
        ``'bbox'`` and ``'score'``; the title text is not included.
        """
        title_type = ModelBlockTypeEnum.TITLE.value
        return self.__get_blocks_by_type(title_type, page_no)

    def get_ocr_text(self, page_no: int) -> list:
        """Return the OCR text spans (text and coordinates) of one page.

        OCR text is the layout detection with category id 15, the same id
        ``get_all_spans`` maps to ``ContentType.Text``.

        Args:
            page_no: index into the model list.

        Returns:
            A list of ``{'bbox': ..., 'content': ...}`` dicts in detection
            order.
        """
        layout_dets = self.__model_list[page_no]['layout_dets']
        # Category ids are compared as ints everywhere else in this class.
        # Matching only the string '15' made this method return [] for
        # integer ids, so both spellings are accepted.
        return [
            {'bbox': det['bbox'], 'content': det['text']}
            for det in layout_dets
            if det['category_id'] in (15, '15')
        ]

    def get_all_spans(self, page_no: int) -> list:
        """Collect every detection of a page that is stitched as a span.

        Category ids treated as spans:
            3  -> ``ContentType.Image``
            5  -> ``ContentType.Table`` (with ``'latex'`` when present,
                  otherwise ``'html'`` when present)
            13 -> ``ContentType.InlineEquation`` (LaTeX in ``'content'``)
            14 -> ``ContentType.InterlineEquation`` (LaTeX in ``'content'``)
            15 -> ``ContentType.Text`` (OCR text in ``'content'``)

        Args:
            page_no: index into the model list.

        Returns:
            Span dicts with ``'bbox'``, ``'score'``, ``'type'`` and the
            category-specific fields, in detection order. Spans equal to an
            earlier one are dropped (the first occurrence is kept).
        """
        unique_spans = []
        for det in self.__model_list[page_no]['layout_dets']:
            category_id = det['category_id']
            if category_id not in (3, 5, 13, 14, 15):
                continue

            span = {'bbox': det['bbox'], 'score': det['score']}
            if category_id == 3:
                span['type'] = ContentType.Image
            elif category_id == 5:
                # Table model output: prefer LaTeX, fall back to HTML.
                latex, html = det.get('latex', None), det.get('html', None)
                if latex:
                    span['latex'] = latex
                elif html:
                    span['html'] = html
                span['type'] = ContentType.Table
            elif category_id == 15:
                span['content'] = det['text']
                span['type'] = ContentType.Text
            else:
                # 13 / 14: equations carry their LaTeX as the span content.
                span['content'] = det['latex']
                span['type'] = (
                    ContentType.InlineEquation
                    if category_id == 13
                    else ContentType.InterlineEquation
                )

            # Spans are dicts (unhashable), so de-duplicate by equality.
            if span not in unique_spans:
                unique_spans.append(span)
        return unique_spans
liukaiwen's avatar
liukaiwen committed
1015
1016
1017

    def get_page_size(self, page_no: int):
        """Return ``(width, height)`` of a page.

        The values are the ``w`` and ``h`` attributes of the page info
        obtained via ``self.__docs.get_page(page_no).get_page_info()``.
        """
        page_info = self.__docs.get_page(page_no).get_page_info()
        return page_info.w, page_info.h

1024
1025
1026
    def __get_blocks_by_type(
        self,
        type: int,
        page_no: int,
        extra_col: 'list[str] | tuple[str, ...]' = (),
    ) -> list:
        """Return the layout detections of one category on one page.

        Args:
            type: category id to keep. It is compared with each detection's
                ``'category_id'``; detections without one count as ``-1``.
            page_no: page to read. It is matched against
                ``page_info['page_no']`` of every model-list entry (entries
                without one count as ``-1``).
            extra_col: additional keys to copy from each detection. A key the
                detection lacks is copied as ``None``. The default is an
                immutable tuple rather than a shared mutable ``[]``.

        Returns:
            One dict per matching detection, in detection order, with
            ``'bbox'``, ``'score'`` and the ``extra_col`` keys.
        """
        blocks = []
        for page_dict in self.__model_list:
            page_info = page_dict.get('page_info', {})
            if page_no != page_info.get('page_no', -1):
                continue
            for item in page_dict.get('layout_dets', []):
                if item.get('category_id', -1) != type:
                    continue
                block = {
                    'bbox': item.get('bbox', None),
                    'score': item.get('score'),
                }
                for col in extra_col:
                    block[col] = item.get(col, None)
                blocks.append(block)
        return blocks

许瑞's avatar
许瑞 committed
1048
1049
1050
    def get_model_list(self, page_no):
        """Return the raw model-output entry stored at index ``page_no``."""
        model_list = self.__model_list
        return model_list[page_no]

1051

1052
if __name__ == '__main__':
    # Developer smoke test: load a model-output JSON and its PDF, build a
    # MagicModel over them and print the image groups of the first 7 pages.
    # The paths are specific to the original author's environment.
    drw = FileBasedDataReader(r'D:/project/20231108code-clean')

    from magic_pdf.data.dataset import PymuDocDataset

    model_list = json.loads(
        drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
    )
    pdf_bytes = drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf')

    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
    for i in range(7):
        print(magic_model.get_imgs(i))