span_block_fix.py 8.59 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
# Copyright (c) Opendatalab. All rights reserved.
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.ocr_utils import _is_overlaps_y_exceeds_threshold, _is_overlaps_x_exceeds_threshold

# A span is counted as vertical when its bbox height divided by its bbox
# width is greater than this value (used by fix_text_block).
VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD = 2
# A block is processed as vertical text when the fraction of its spans that
# are counted as vertical is greater than this value (used by fix_text_block).
VERTICAL_SPAN_IN_BLOCK_THRESHOLD = 0.8

def fill_spans_in_blocks(blocks, spans, radio):
    """Distribute spans into blocks according to their positions.

    Blocks are visited in order. A span is placed in the first block for which
    the overlap ratio returned by ``calculate_overlap_area_in_bbox1_area_ratio``
    exceeds the required minimum and whose type is compatible with the span
    type. A placed span is removed from ``spans`` and is therefore never
    offered to a later block.

    Args:
        blocks: Iterable of indexable block records. Elements 0-3 are read as
            the bbox and element 7 as the block type. For image/table body,
            caption and footnote blocks the last element is stored as
            ``group_id``.
        spans: List of span dicts providing ``'bbox'`` and ``'type'``. The
            list is modified in place (placed spans are removed).
        radio: Minimum overlap ratio for a span to belong to a block
            (presumably the share of the span's area lying inside the block;
            confirm against the helper). Image and table spans always use 0.9
            instead.

    Returns:
        A tuple ``(block_with_spans, spans)``: a list of dicts with the keys
        ``'type'``, ``'bbox'``, optionally ``'group_id'``, and ``'spans'``,
        plus the same ``spans`` list object, now holding only unplaced spans.
    """
    block_with_spans = []
    for block in blocks:
        block_type = block[7]
        block_bbox = block[0:4]
        block_dict = {
            'type': block_type,
            'bbox': block_bbox,
        }
        if block_type in (
            BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
            BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE,
        ):
            block_dict['group_id'] = block[-1]

        # Split the remaining spans in one pass instead of calling
        # list.remove() once per placed span: that was quadratic and located
        # each span by dict equality rather than by position.
        placed = []
        leftover = []
        for span in spans:
            # Image and table spans must be almost fully inside the block.
            if span['type'] in (ContentType.IMAGE, ContentType.TABLE):
                min_ratio = 0.9
            else:
                min_ratio = radio
            belongs = (
                calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > min_ratio
                and span_block_type_compatible(span['type'], block_type)
            )
            if belongs:
                placed.append(span)
            else:
                leftover.append(span)

        block_dict['spans'] = placed
        block_with_spans.append(block_dict)

        if placed:
            # Slice assignment keeps the caller's list object, so the in-place
            # removal contract of the original implementation is preserved.
            spans[:] = leftover

    return block_with_spans, spans


def span_block_type_compatible(span_type, block_type):
    """Return whether a span of ``span_type`` may be placed in ``block_type``.

    Text and inline-equation spans fit text, title, image/table caption,
    image/table footnote and discarded blocks. Interline-equation spans fit
    interline-equation and text blocks. Image spans fit only image bodies and
    table spans only table bodies. Any other span type fits nothing.
    """
    if span_type in (ContentType.TEXT, ContentType.INLINE_EQUATION):
        allowed_blocks = (
            BlockType.TEXT,
            BlockType.TITLE,
            BlockType.IMAGE_CAPTION,
            BlockType.IMAGE_FOOTNOTE,
            BlockType.TABLE_CAPTION,
            BlockType.TABLE_FOOTNOTE,
            BlockType.DISCARDED,
        )
    elif span_type == ContentType.INTERLINE_EQUATION:
        allowed_blocks = (BlockType.INTERLINE_EQUATION, BlockType.TEXT)
    elif span_type == ContentType.IMAGE:
        allowed_blocks = (BlockType.IMAGE_BODY,)
    elif span_type == ContentType.TABLE:
        allowed_blocks = (BlockType.TABLE_BODY,)
    else:
        # Unknown span types are never compatible.
        allowed_blocks = ()
    return block_type in allowed_blocks


def fix_discarded_block(discarded_block_with_spans):
    """Run ``fix_text_block`` on every discarded block.

    Returns a new list holding the result for each block, in input order.
    """
    return [fix_text_block(block) for block in discarded_block_with_spans]


def fix_text_block(block):
    """Replace a text-like block's ``'spans'`` with sorted ``'lines'``.

    Steps:
      1. Every interline-equation span is retyped as an inline equation,
         because formulas inside a text block are treated as inline.
      2. The block is classified as vertical text when more than
         ``VERTICAL_SPAN_IN_BLOCK_THRESHOLD`` of its spans have a bbox
         height/width ratio above
         ``VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD``.
      3. Spans are merged into vertical or horizontal lines accordingly and
         the spans inside each line are ordered.

    Args:
        block: Dict with a ``'spans'`` list; each span provides ``'type'`` and
            a ``'bbox'`` indexed 0-3 as x0, y0, x1, y1. The dict and its span
            dicts are modified in place.

    Returns:
        The same ``block`` dict, with ``'lines'`` set and ``'spans'`` removed.
    """
    spans = block['spans']

    # Formula spans inside a text block are always inline.
    for span in spans:
        if span['type'] == ContentType.INTERLINE_EQUATION:
            span['type'] = ContentType.INLINE_EQUATION

    def _is_vertical_span(span):
        bbox = span['bbox']
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]
        if width == 0:
            # A zero-width bbox used to raise ZeroDivisionError. Treat it as
            # the limit of the ratio: vertical only if it has positive height.
            return height > 0
        return height / width > VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD

    total_span_count = len(spans)
    if total_span_count == 0:
        vertical_ratio = 0
    else:
        vertical_span_count = sum(1 for span in spans if _is_vertical_span(span))
        vertical_ratio = vertical_span_count / total_span_count

    if vertical_ratio > VERTICAL_SPAN_IN_BLOCK_THRESHOLD:
        # Vertical text block: build columns, then order each top to bottom.
        block_lines = merge_spans_to_vertical_line(spans)
        sort_block_lines = vertical_line_sort_spans_from_top_to_bottom(block_lines)
    else:
        block_lines = merge_spans_to_line(spans)
        sort_block_lines = line_sort_spans_by_left_to_right(block_lines)

    block['lines'] = sort_block_lines
    del block['spans']
    return block


def merge_spans_to_line(spans, threshold=0.6):
    """Group spans into horizontal lines.

    ``spans`` is sorted in place by the top coordinate (``bbox[1]``) and then
    walked once. A span joins the current line when
    ``_is_overlaps_y_exceeds_threshold`` accepts it against the line's last
    span. Interline-equation, image and table spans always sit on a line of
    their own: such a span starts a new line, and so does any span that
    follows one.

    Args:
        spans: List of span dicts with ``'bbox'`` and ``'type'``.
        threshold: Value forwarded to ``_is_overlaps_y_exceeds_threshold``.

    Returns:
        A list of lines, each a non-empty list of spans; ``[]`` for no spans.
    """
    if not spans:
        return []

    def _stands_alone(item):
        # Span types that must not share a line with anything else.
        return item['type'] in (
            ContentType.INTERLINE_EQUATION, ContentType.IMAGE, ContentType.TABLE,
        )

    # Order by y0 so that consecutive spans are vertical neighbours.
    spans.sort(key=lambda item: item['bbox'][1])

    first, *rest = spans
    lines = []
    row = [first]
    for item in rest:
        joins_row = (
            not _stands_alone(item)
            and not any(_stands_alone(member) for member in row)
            and _is_overlaps_y_exceeds_threshold(item['bbox'], row[-1]['bbox'], threshold)
        )
        if joins_row:
            row.append(item)
        else:
            # Close the current line and open a new one with this span.
            lines.append(row)
            row = [item]

    # The row being built is never empty, so it is always the final line.
    lines.append(row)
    return lines


def merge_spans_to_vertical_line(spans, threshold=0.6):
    """Group the spans of vertical text into columns, read right to left.

    ``spans`` is sorted in place by the right edge (``bbox[2]``) in descending
    order and then walked once. A span joins the current column when
    ``_is_overlaps_x_exceeds_threshold`` accepts it against the column's last
    span. Interline-equation, image and table spans always form a column of
    their own: such a span starts a new column, and so does any span that
    follows one.

    Args:
        spans: List of span dicts with ``'bbox'`` and ``'type'``.
        threshold: Value forwarded to ``_is_overlaps_x_exceeds_threshold``.

    Returns:
        A list of columns, each a non-empty list of spans; ``[]`` for no spans.
    """
    if not spans:
        return []

    def _stands_alone(item):
        # Span types that must occupy a column by themselves.
        return item['type'] in (
            ContentType.INTERLINE_EQUATION, ContentType.IMAGE, ContentType.TABLE,
        )

    # Rightmost spans first (right-to-left reading order).
    spans.sort(key=lambda item: item['bbox'][2], reverse=True)

    first, *rest = spans
    vertical_lines = []
    column = [first]
    for item in rest:
        joins_column = (
            not _stands_alone(item)
            and not any(_stands_alone(member) for member in column)
            # Same column when the span overlaps the previous one on the x axis.
            and _is_overlaps_x_exceeds_threshold(item['bbox'], column[-1]['bbox'], threshold)
        )
        if joins_column:
            column.append(item)
        else:
            vertical_lines.append(column)
            column = [item]

    # The column being built is never empty, so it is always the final one.
    vertical_lines.append(column)
    return vertical_lines


# Sort the spans within each line from left to right.
def line_sort_spans_by_left_to_right(lines):
    """Order each line's spans left to right and wrap the line in a dict.

    Every inner list is sorted in place by the span's left edge (``bbox[0]``).

    Args:
        lines: List of non-empty lists of span dicts that carry ``'bbox'``.

    Returns:
        One dict per line with ``'bbox'`` (the union box of the line's spans
        as ``[x0, y0, x1, y1]``) and ``'spans'`` (the same, now sorted, list).
    """
    line_objects = []
    for members in lines:
        members.sort(key=lambda item: item['bbox'][0])
        boxes = [item['bbox'] for item in members]
        # Union of all span boxes: smallest top-left, largest bottom-right.
        union_bbox = [
            min(box[0] for box in boxes),
            min(box[1] for box in boxes),
            max(box[2] for box in boxes),
            max(box[3] for box in boxes),
        ]
        line_objects.append({'bbox': union_bbox, 'spans': members})
    return line_objects


def vertical_line_sort_spans_from_top_to_bottom(vertical_lines):
    """Order each column's spans top to bottom and wrap the column in a dict.

    Every inner list is sorted in place by the span's top edge (``bbox[1]``).

    Args:
        vertical_lines: List of non-empty lists of span dicts with ``'bbox'``.

    Returns:
        One dict per column with ``'bbox'`` (the union box of the column's
        spans as ``[x0, y0, x1, y1]``) and ``'spans'`` (the same, now sorted,
        list).
    """
    columns = []
    for column in vertical_lines:
        # Top-to-bottom reading order inside a column.
        column.sort(key=lambda item: item['bbox'][1])

        # Bounding box enclosing the whole column.
        left = min(item['bbox'][0] for item in column)
        top = min(item['bbox'][1] for item in column)
        right = max(item['bbox'][2] for item in column)
        bottom = max(item['bbox'][3] for item in column)

        columns.append({
            'bbox': [left, top, right, bottom],
            'spans': column,
        })
    return columns


def fix_block_spans(block_with_spans):
    """Convert the spans of every supported block into sorted lines.

    Text-like blocks (text, title, image caption/footnote, table
    caption/footnote) go through ``fix_text_block``. Interline-equation,
    image-body and table-body blocks go through ``fix_interline_block``.
    Blocks of any other type are left out of the result.

    The original type list named ``BlockType.IMAGE_CAPTION`` twice and omitted
    ``BlockType.IMAGE_FOOTNOTE``, so image footnote blocks were silently
    dropped; they are now handled like the other caption/footnote blocks.

    Args:
        block_with_spans: Iterable of block dicts with ``'type'`` and
            ``'spans'``; handled blocks are modified in place.

    Returns:
        A list of the handled blocks, in input order.
    """
    fix_blocks = []
    for block in block_with_spans:
        block_type = block['type']

        if block_type in (
            BlockType.TEXT, BlockType.TITLE,
            BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
            BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE,
        ):
            fix_blocks.append(fix_text_block(block))
        elif block_type in (
            BlockType.INTERLINE_EQUATION, BlockType.IMAGE_BODY, BlockType.TABLE_BODY,
        ):
            fix_blocks.append(fix_interline_block(block))
        # Any other block type is intentionally skipped.
    return fix_blocks


def fix_interline_block(block):
    """Replace the block's ``'spans'`` with horizontally merged, sorted lines.

    The spans are grouped by ``merge_spans_to_line`` and each resulting line
    is ordered by ``line_sort_spans_by_left_to_right``. The block dict is
    modified in place and returned.
    """
    merged_lines = merge_spans_to_line(block['spans'])
    block['lines'] = line_sort_spans_by_left_to_right(merged_lines)
    del block['spans']
    return block