para_split.py.bak 32.4 KB
Newer Older
xuchao's avatar
xuchao committed
1
2
import numpy as np
from loguru import logger
3
from sklearn.cluster import DBSCAN
xuchao's avatar
xuchao committed
4

5
6
7
from magic_pdf.config.ocr_content_type import ContentType
from magic_pdf.libs.boxbase import \
    _is_in_or_part_overlap_with_area_ratio as is_in_layout
xuchao's avatar
xuchao committed
8

9
LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ':', ':', ')', ')', ';']
xuchao's avatar
xuchao committed
10
11
INLINE_EQUATION = ContentType.InlineEquation
INTERLINE_EQUATION = ContentType.InterlineEquation
kernel.h@qq.com's avatar
kernel.h@qq.com committed
12
TEXT = ContentType.Text
xuchao's avatar
xuchao committed
13

xuchao's avatar
xuchao committed
14
15
16

def __get_span_text(span):
    c = span.get('content', '')
17
    if len(c) == 0:
赵小蒙's avatar
赵小蒙 committed
18
        c = span.get('image_path', '')
19

xuchao's avatar
xuchao committed
20
    return c
21

22

kernel.h@qq.com's avatar
kernel.h@qq.com committed
23
def __detect_list_lines(lines, new_layout_bboxes, lang):
24
25
    """探测是否包含了列表,并且把列表的行分开.

26
27
    这样的段落特点是,顶格字母大写/数字,紧跟着几行缩进的。缩进的行首字母含小写的。
    """
28

29
30
31
32
33
    def find_repeating_patterns(lst):
        indices = []
        ones_indices = []
        i = 0
        while i < len(lst) - 1:  # 确保余下元素至少有2个
34
            if lst[i] == 1 and lst[i + 1] in [2, 3]:  # 额外检查以防止连续出现的1
35
36
37
38
39
40
                start = i
                ones_in_this_interval = [i]
                i += 1
                while i < len(lst) and lst[i] in [2, 3]:
                    i += 1
                # 验证下一个序列是否符合条件
41
42
43
44
45
46
                if (
                    i < len(lst) - 1
                    and lst[i] == 1
                    and lst[i + 1] in [2, 3]
                    and lst[i - 1] in [2, 3]
                ):
47
48
49
50
51
52
53
54
55
56
57
                    while i < len(lst) and lst[i] in [1, 2, 3]:
                        if lst[i] == 1:
                            ones_in_this_interval.append(i)
                        i += 1
                    indices.append((start, i - 1))
                    ones_indices.append(ones_in_this_interval)
                else:
                    i += 1
            else:
                i += 1
        return indices, ones_indices
58

59
    """===================="""
60

61
62
63
    def split_indices(slen, index_array):
        result = []
        last_end = 0
64

65
66
67
68
69
70
71
72
73
74
75
76
77
        for start, end in sorted(index_array):
            if start > last_end:
                # 前一个区间结束到下一个区间开始之间的部分标记为"text"
                result.append(('text', last_end, start - 1))
            # 区间内标记为"list"
            result.append(('list', start, end))
            last_end = end + 1

        if last_end < slen:
            # 如果最后一个区间结束后还有剩余的字符串,将其标记为"text"
            result.append(('text', last_end, slen - 1))

        return result
78

79
80
    """===================="""

81
    if lang != 'en':
82
83
84
85
86
87
88
89
90
91
92
        return lines, None
    else:
        total_lines = len(lines)
        line_fea_encode = []
        """
        对每一行进行特征编码,编码规则如下:
        1. 如果行顶格,且大写字母开头或者数字开头,编码为1
        2. 如果顶格,其他非大写开头编码为4
        3. 如果非顶格,首字符大写,编码为2
        4. 如果非顶格,首字符非大写编码为3
        """
93
        for l in lines:  # noqa: E741
94
95
96
97
98
99
100
101
102
103
104
105
            first_char = __get_span_text(l['spans'][0])[0]
            layout_left = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)[0]
            if l['bbox'][0] == layout_left:
                if first_char.isupper() or first_char.isdigit():
                    line_fea_encode.append(1)
                else:
                    line_fea_encode.append(4)
            else:
                if first_char.isupper():
                    line_fea_encode.append(2)
                else:
                    line_fea_encode.append(3)
106

107
        # 然后根据编码进行分段, 选出来 1,2,3连续出现至少2次的行,认为是列表。
108
109
110
111
112

        list_indice, list_start_idx = find_repeating_patterns(line_fea_encode)
        if len(list_indice) > 0:
            logger.info(f'发现了列表,列表行数:{list_indice}, {list_start_idx}')

113
        # TODO check一下这个特列表里缩进的行左侧是不是对齐的。
114

115
        for start, end in list_indice:
116
117
            for i in range(start, end + 1):
                if i > 0:
118
                    if line_fea_encode[i] == 4:
119
                        logger.info(f'列表行的第{i}行不是顶格的')
120
121
                        break
            else:
122
123
                logger.info(f'列表行的第{start}到第{end}行是列表')

124
        return split_indices(total_lines, list_indice), list_start_idx
125

xuchao's avatar
xuchao committed
126
127

def __valign_lines(blocks, layout_bboxes):
128
129
130
131
    """在一个layoutbox内对齐行的左侧和右侧。 扫描行的左侧和右侧,如果x0,
    x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
    3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。"""

xuchao's avatar
xuchao committed
132
133
    min_distance = 3
    min_sample = 2
xuchao's avatar
xuchao committed
134
    new_layout_bboxes = []
135

xuchao's avatar
xuchao committed
136
    for layout_box in layout_bboxes:
137
138
139
140
        blocks_in_layoutbox = [
            b for b in blocks if is_in_layout(b['bbox'], layout_box['layout_bbox'])
        ]
        if len(blocks_in_layoutbox) == 0:
xuchao's avatar
xuchao committed
141
            continue
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156

        x0_lst = np.array(
            [
                [line['bbox'][0], 0]
                for block in blocks_in_layoutbox
                for line in block['lines']
            ]
        )
        x1_lst = np.array(
            [
                [line['bbox'][2], 0]
                for block in blocks_in_layoutbox
                for line in block['lines']
            ]
        )
xuchao's avatar
xuchao committed
157
158
159
160
        x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst)
        x1_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x1_lst)
        x0_uniq_label = np.unique(x0_clusters.labels_)
        x1_uniq_label = np.unique(x1_clusters.labels_)
161
162

        x0_2_new_val = {}  # 存储旧值对应的新值映射
xuchao's avatar
xuchao committed
163
164
        x1_2_new_val = {}
        for label in x0_uniq_label:
165
            if label == -1:
xuchao's avatar
xuchao committed
166
                continue
167
168
169
            x0_index_of_label = np.where(x0_clusters.labels_ == label)
            x0_raw_val = x0_lst[x0_index_of_label][:, 0]
            x0_new_val = np.min(x0_lst[x0_index_of_label][:, 0])
xuchao's avatar
xuchao committed
170
171
            x0_2_new_val.update({idx: x0_new_val for idx in x0_raw_val})
        for label in x1_uniq_label:
172
            if label == -1:
xuchao's avatar
xuchao committed
173
                continue
174
175
176
            x1_index_of_label = np.where(x1_clusters.labels_ == label)
            x1_raw_val = x1_lst[x1_index_of_label][:, 0]
            x1_new_val = np.max(x1_lst[x1_index_of_label][:, 0])
xuchao's avatar
xuchao committed
177
            x1_2_new_val.update({idx: x1_new_val for idx in x1_raw_val})
178

xuchao's avatar
xuchao committed
179
180
181
182
183
184
185
186
187
        for block in blocks_in_layoutbox:
            for line in block['lines']:
                x0, x1 = line['bbox'][0], line['bbox'][2]
                if x0 in x0_2_new_val:
                    line['bbox'][0] = int(x0_2_new_val[x0])

                if x1 in x1_2_new_val:
                    line['bbox'][2] = int(x1_2_new_val[x1])
            # 其余对不齐的保持不动
188

xuchao's avatar
xuchao committed
189
190
        # 由于修改了block里的line长度,现在需要重新计算block的bbox
        for block in blocks_in_layoutbox:
191
192
193
194
195
196
197
            block['bbox'] = [
                min([line['bbox'][0] for line in block['lines']]),
                min([line['bbox'][1] for line in block['lines']]),
                max([line['bbox'][2] for line in block['lines']]),
                max([line['bbox'][3] for line in block['lines']]),
            ]

xuchao's avatar
xuchao committed
198
199
200
201
202
203
        """新计算layout的bbox,因为block的bbox变了。"""
        layout_x0 = min([block['bbox'][0] for block in blocks_in_layoutbox])
        layout_y0 = min([block['bbox'][1] for block in blocks_in_layoutbox])
        layout_x1 = max([block['bbox'][2] for block in blocks_in_layoutbox])
        layout_y1 = max([block['bbox'][3] for block in blocks_in_layoutbox])
        new_layout_bboxes.append([layout_x0, layout_y0, layout_x1, layout_y1])
204

xuchao's avatar
xuchao committed
205
    return new_layout_bboxes
xuchao's avatar
xuchao committed
206
207


208
def __align_text_in_layout(blocks, layout_bboxes):
209
    """由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。"""
210
211
212
    for layout in layout_bboxes:
        lb = layout['layout_bbox']
        blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], lb)]
213
        if len(blocks_in_layoutbox) == 0:
214
            continue
215

216
217
218
219
220
221
222
        for block in blocks_in_layoutbox:
            for line in block['lines']:
                x0, x1 = line['bbox'][0], line['bbox'][2]
                if x0 < lb[0]:
                    line['bbox'][0] = lb[0]
                if x1 > lb[2]:
                    line['bbox'][2] = lb[2]
223
224


xuchao's avatar
xuchao committed
225
def __common_pre_proc(blocks, layout_bboxes):
226
227
    """不分语言的,对文本进行预处理."""
    # __add_line_period(blocks, layout_bboxes)
228
    __align_text_in_layout(blocks, layout_bboxes)
xuchao's avatar
xuchao committed
229
    aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)
230

xuchao's avatar
xuchao committed
231
    return aligned_layout_bboxes
xuchao's avatar
xuchao committed
232

233

xuchao's avatar
xuchao committed
234
def __pre_proc_zh_blocks(blocks, layout_bboxes):
235
    """对中文文本进行分段预处理."""
xuchao's avatar
xuchao committed
236
237
238
239
    pass


def __pre_proc_en_blocks(blocks, layout_bboxes):
240
    """对英文文本进行分段预处理."""
xuchao's avatar
xuchao committed
241
242
243
    pass


244
245
def __group_line_by_layout(blocks, layout_bboxes, lang='en'):
    """每个layout内的行进行聚合."""
xuchao's avatar
xuchao committed
246
247
    # 因为只是一个block一行目前, 一个block就是一个段落
    lines_group = []
248

xuchao's avatar
xuchao committed
249
    for lyout in layout_bboxes:
250
251
252
253
254
255
        lines = [
            line
            for block in blocks
            if is_in_layout(block['bbox'], lyout['layout_bbox'])
            for line in block['lines']
        ]
xuchao's avatar
xuchao committed
256
257
258
259
        lines_group.append(lines)

    return lines_group

260
261

def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang='en', char_avg_len=10):
xuchao's avatar
xuchao committed
262
    """
263
    lines_group 进行行分段——layout内部进行分段。lines_group内每个元素是一个Layoutbox内的所有行。
xuchao's avatar
xuchao committed
264
265
266
    1. 先计算每个group的左右边界。
    2. 然后根据行末尾特征进行分段。
        末尾特征:以句号等结束符结尾。并且距离右侧边界有一定距离。
267
        且下一行开头不留空白。
268

xuchao's avatar
xuchao committed
269
    """
270
    list_info = []  # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
271
    layout_paras = []
xuchao's avatar
xuchao committed
272
    right_tail_distance = 1.5 * char_avg_len
273

xuchao's avatar
xuchao committed
274
    for lines in lines_group:
275
        paras = []
276
        total_lines = len(lines)
277
278
279
        if total_lines == 0:
            continue  # 0行无需处理
        if total_lines == 1:  # 1行无法分段。
280
281
            layout_paras.append([lines])
            list_info.append([False, False])
xuchao's avatar
xuchao committed
282
            continue
283

284
285
286
287
288
289
290
291
        """在进入到真正的分段之前,要对文字块从统计维度进行对齐方式的探测,
            对齐方式分为以下:
            1. 左对齐的文本块(特点是左侧顶格,或者左侧不顶格但是右侧顶格的行数大于非顶格的行数,顶格的首字母有大写也有小写)
                1) 右侧对齐的行,单独成一段
                2) 中间对齐的行,按照字体/行高聚合成一段
            2. 左对齐的列表块(其特点是左侧顶格的行数小于等于非顶格的行数,非定格首字母会有小写,顶格90%是大写。并且左侧顶格行数大于1,大于1是为了这种模式连续出现才能称之为列表)
                这样的文本块,顶格的为一个段落开头,紧随其后非顶格的行属于这个段落。
        """
292
293
294
295

        text_segments, list_start_line = __detect_list_lines(
            lines, new_layout_bbox, lang
        )
296
        """根据list_range,把lines分成几个部分
297

298
        """
299

300
        layout_right = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[2]
kernel.h@qq.com's avatar
merge  
kernel.h@qq.com committed
301
        layout_left = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[0]
302
303
304
305
306
        para = []  # 元素是line
        layout_list_info = [
            False,
            False,
        ]  # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
307
308
        for content_type, start, end in text_segments:
            if content_type == 'list':
309
                for i, line in enumerate(lines[start : end + 1]):
310
                    line_x0 = line['bbox'][0]
311
312
                    if line_x0 == layout_left:  # 列表开头
                        if len(para) > 0:
313
314
315
316
317
                            paras.append(para)
                            para = []
                        para.append(line)
                    else:
                        para.append(line)
318
                if len(para) > 0:
xuchao's avatar
xuchao committed
319
320
                    paras.append(para)
                    para = []
321
                if start == 0:
322
                    layout_list_info[0] = True
323
                if end == total_lines - 1:
324
                    layout_list_info[1] = True
325
326
            else:  # 是普通文本
                for i, line in enumerate(lines[start : end + 1]):
kernel.h@qq.com's avatar
update  
kernel.h@qq.com committed
327
                    # 如果i有下一行,那么就要根据下一行位置综合判断是否要分段。如果i之后没有行,那么只需要判断i行自己的结尾特征。
328
                    cur_line_type = line['spans'][-1]['type']
329
330
                    next_line = lines[i + 1] if i < total_lines - 1 else None

331
332
333
334
335
                    if cur_line_type in [TEXT, INLINE_EQUATION]:
                        if line['bbox'][2] < layout_right - right_tail_distance:
                            para.append(line)
                            paras.append(para)
                            para = []
336
337
338
339
340
                        elif (
                            line['bbox'][2] >= layout_right - right_tail_distance
                            and next_line
                            and next_line['bbox'][0] == layout_left
                        ):  # 现在这行到了行尾沾满,下一行存在且顶格。
341
                            para.append(line)
342
                        else:
343
344
345
                            para.append(line)
                            paras.append(para)
                            para = []
346
347
                    else:  # 其他,图片、表格、行间公式,各自占一段
                        if len(para) > 0:  # 先把之前的段落加入到结果中
348
349
                            paras.append(para)
                            para = []
350
351
352
                        paras.append(
                            [line]
                        )  # 再把当前行加入到结果中。当前行为行间公式、图、表等。
353
                        para = []
354
355

                if len(para) > 0:
xuchao's avatar
xuchao committed
356
357
                    paras.append(para)
                    para = []
358

359
360
361
        list_info.append(layout_list_info)
        layout_paras.append(paras)
        paras = []
362

363
364
    return layout_paras, list_info

365
366
367
368
369
370
371
372
373
374

def __connect_list_inter_layout(
    layout_paras, new_layout_bbox, layout_list_info, page_num, lang
):
    """如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
    因为没有区分列表和段落,所以这个方法暂时不实现。
    根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
    if (
        len(layout_paras) == 0 or len(layout_list_info) == 0
    ):  # 0的时候最后的return 会出错
kernel.h@qq.com's avatar
kernel.h@qq.com committed
375
        return layout_paras, [False, False]
376

kernel.h@qq.com's avatar
kernel.h@qq.com committed
377
    for i in range(1, len(layout_paras)):
378
        pre_layout_list_info = layout_list_info[i - 1]
kernel.h@qq.com's avatar
kernel.h@qq.com committed
379
        next_layout_list_info = layout_list_info[i]
380
        pre_last_para = layout_paras[i - 1][-1]
kernel.h@qq.com's avatar
kernel.h@qq.com committed
381
        next_paras = layout_paras[i]
382
383
384
385
386

        if (
            pre_layout_list_info[1] and not next_layout_list_info[0]
        ):  # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
            logger.info(f'连接page {page_num} 内的list')
kernel.h@qq.com's avatar
kernel.h@qq.com committed
387
388
389
390
            # 向layout_paras[i] 寻找开头具有相同缩进的连续的行
            may_list_lines = []
            for j in range(len(next_paras)):
                line = next_paras[j]
391
392
393
394
395
396
397
                if len(line) == 1:  # 只可能是一行,多行情况再需要分析了
                    if (
                        line[0]['bbox'][0]
                        > __find_layout_bbox_by_line(line[0]['bbox'], new_layout_bbox)[
                            0
                        ]
                    ):
kernel.h@qq.com's avatar
kernel.h@qq.com committed
398
399
400
401
402
403
                        may_list_lines.append(line[0])
                    else:
                        break
                else:
                    break
            # 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
404
405
406
407
            if (
                len(may_list_lines) > 0
                and len(set([x['bbox'][0] for x in may_list_lines])) == 1
            ):
kernel.h@qq.com's avatar
kernel.h@qq.com committed
408
                pre_last_para.extend(may_list_lines)
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
                layout_paras[i] = layout_paras[i][len(may_list_lines) :]

    return layout_paras, [
        layout_list_info[0][0],
        layout_list_info[-1][1],
    ]  # 同时还返回了这个页面级别的开头、结尾是不是列表的信息


def __connect_list_inter_page(
    pre_page_paras,
    next_page_paras,
    pre_page_layout_bbox,
    next_page_layout_bbox,
    pre_page_list_info,
    next_page_list_info,
    page_num,
    lang,
):
    """如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
    因为没有区分列表和段落,所以这个方法暂时不实现。
    根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
    if (
        len(pre_page_paras) == 0 or len(next_page_paras) == 0
    ):  # 0的时候最后的return 会出错
kernel.h@qq.com's avatar
kernel.h@qq.com committed
433
        return False
434
435
436
437
438

    if (
        pre_page_list_info[1] and not next_page_list_info[0]
    ):  # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
        logger.info(f'连接page {page_num} 内的list')
kernel.h@qq.com's avatar
kernel.h@qq.com committed
439
440
441
442
        # 向layout_paras[i] 寻找开头具有相同缩进的连续的行
        may_list_lines = []
        for j in range(len(next_page_paras[0])):
            line = next_page_paras[0][j]
443
444
445
446
447
448
449
            if len(line) == 1:  # 只可能是一行,多行情况再需要分析了
                if (
                    line[0]['bbox'][0]
                    > __find_layout_bbox_by_line(
                        line[0]['bbox'], next_page_layout_bbox
                    )[0]
                ):
kernel.h@qq.com's avatar
kernel.h@qq.com committed
450
451
452
453
454
455
                    may_list_lines.append(line[0])
                else:
                    break
            else:
                break
        # 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
456
457
458
459
        if (
            len(may_list_lines) > 0
            and len(set([x['bbox'][0] for x in may_list_lines])) == 1
        ):
kernel.h@qq.com's avatar
kernel.h@qq.com committed
460
            pre_page_paras[-1].append(may_list_lines)
461
            next_page_paras[0] = next_page_paras[0][len(may_list_lines) :]
kernel.h@qq.com's avatar
kernel.h@qq.com committed
462
            return True
463

kernel.h@qq.com's avatar
kernel.h@qq.com committed
464
    return False
xuchao's avatar
xuchao committed
465

xuchao's avatar
xuchao committed
466
467

def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
468
    """根据line找到所在的layout."""
xuchao's avatar
xuchao committed
469
    for layout in layout_bboxes:
470
        if is_in_layout(line_bbox, layout):
xuchao's avatar
xuchao committed
471
472
473
474
            return layout
    return None


kernel.h@qq.com's avatar
kernel.h@qq.com committed
475
def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang):
xuchao's avatar
xuchao committed
476
477
478
479
480
481
482
483
484
    """
    layout之间进行分段。
    主要是计算前一个layOut的最后一行和后一个layout的第一行是否可以连接。
    连接的条件需要同时满足:
    1. 上一个layout的最后一行沾满整个行。并且没有结尾符号。
    2. 下一行开头不留空白。

    """
    connected_layout_paras = []
485
    if len(layout_paras) == 0:
kernel.h@qq.com's avatar
kernel.h@qq.com committed
486
        return connected_layout_paras
487

488
489
    connected_layout_paras.append(layout_paras[0])
    for i in range(1, len(layout_paras)):
kernel.h@qq.com's avatar
kernel.h@qq.com committed
490
        try:
491
492
493
            if (
                len(layout_paras[i]) == 0 or len(layout_paras[i - 1]) == 0
            ):  # TODO 考虑连接问题,
kernel.h@qq.com's avatar
kernel.h@qq.com committed
494
                continue
495
            pre_last_line = layout_paras[i - 1][-1][-1]
kernel.h@qq.com's avatar
kernel.h@qq.com committed
496
            next_first_line = layout_paras[i][0][0]
497
498
        except Exception:
            logger.error(f'page layout {i} has no line')
kernel.h@qq.com's avatar
kernel.h@qq.com committed
499
            continue
500
501
502
        pre_last_line_text = ''.join(
            [__get_span_text(span) for span in pre_last_line['spans']]
        )
xuchao's avatar
xuchao committed
503
        pre_last_line_type = pre_last_line['spans'][-1]['type']
504
505
506
        next_first_line_text = ''.join(
            [__get_span_text(span) for span in next_first_line['spans']]
        )
xuchao's avatar
xuchao committed
507
        next_first_line_type = next_first_line['spans'][0]['type']
508
509
510
511
        if pre_last_line_type not in [
            TEXT,
            INLINE_EQUATION,
        ] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
512
            connected_layout_paras.append(layout_paras[i])
xuchao's avatar
xuchao committed
513
            continue
514
515
516
517
518
519
520
521

        pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)[
            2
        ]
        next_x0_min = __find_layout_bbox_by_line(
            next_first_line['bbox'], new_layout_bbox
        )[0]

xuchao's avatar
xuchao committed
522
523
        pre_last_line_text = pre_last_line_text.strip()
        next_first_line_text = next_first_line_text.strip()
524
525
526
527
528
        if (
            pre_last_line['bbox'][2] == pre_x2_max
            and pre_last_line_text[-1] not in LINE_STOP_FLAG
            and next_first_line['bbox'][0] == next_x0_min
        ):  # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
xuchao's avatar
xuchao committed
529
            """连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
530
            connected_layout_paras[-1][-1].extend(layout_paras[i][0])
531
532
533
534
            layout_paras[i].pop(
                0
            )  # 删除后一个layout的第一个段落, 因为他已经被合并到前一个layout的最后一个段落了。
            if len(layout_paras[i]) == 0:
535
536
537
                layout_paras.pop(i)
            else:
                connected_layout_paras.append(layout_paras[i])
538
        else:
xuchao's avatar
xuchao committed
539
            """连接段落条件不成立,将前一个layout的段落加入到结果中。"""
540
            connected_layout_paras.append(layout_paras[i])
541

xuchao's avatar
xuchao committed
542
543
544
    return connected_layout_paras


545
546
547
548
549
550
551
552
def __connect_para_inter_page(
    pre_page_paras,
    next_page_paras,
    pre_page_layout_bbox,
    next_page_layout_bbox,
    page_num,
    lang,
):
553
554
555
556
557
558
    """
    连接起来相邻两个页面的段落——前一个页面最后一个段落和后一个页面的第一个段落。
    是否可以连接的条件:
    1. 前一个页面的最后一个段落最后一行沾满整个行。并且没有结尾符号。
    2. 后一个页面的第一个段落第一行没有空白开头。
    """
kernel.h@qq.com's avatar
kernel.h@qq.com committed
559
    # 有的页面可能压根没有文字
560
561
562
563
564
565
    if (
        len(pre_page_paras) == 0
        or len(next_page_paras) == 0
        or len(pre_page_paras[0]) == 0
        or len(next_page_paras[0]) == 0
    ):  # TODO [[]]为什么出现在pre_page_paras里?
kernel.h@qq.com's avatar
kernel.h@qq.com committed
566
        return False
567
568
    pre_last_para = pre_page_paras[-1][-1]
    next_first_para = next_page_paras[0][0]
569
570
    pre_last_line = pre_last_para[-1]
    next_first_line = next_first_para[0]
571
572
573
    pre_last_line_text = ''.join(
        [__get_span_text(span) for span in pre_last_line['spans']]
    )
574
    pre_last_line_type = pre_last_line['spans'][-1]['type']
575
576
577
    next_first_line_text = ''.join(
        [__get_span_text(span) for span in next_first_line['spans']]
    )
578
    next_first_line_type = next_first_line['spans'][0]['type']
579
580
581
582
583
584
585
586

    if pre_last_line_type not in [
        TEXT,
        INLINE_EQUATION,
    ] or next_first_line_type not in [
        TEXT,
        INLINE_EQUATION,
    ]:  # TODO,真的要做好,要考虑跨table, image, 行间的情况
587
588
        # 不是文本,不连接
        return False
589
590
591
592
593
594
595
596

    pre_x2_max = __find_layout_bbox_by_line(
        pre_last_line['bbox'], pre_page_layout_bbox
    )[2]
    next_x0_min = __find_layout_bbox_by_line(
        next_first_line['bbox'], next_page_layout_bbox
    )[0]

597
598
    pre_last_line_text = pre_last_line_text.strip()
    next_first_line_text = next_first_line_text.strip()
599
600
601
602
603
    if (
        pre_last_line['bbox'][2] == pre_x2_max
        and pre_last_line_text[-1] not in LINE_STOP_FLAG
        and next_first_line['bbox'][0] == next_x0_min
    ):  # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
604
        """连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
605
        pre_last_para.extend(next_first_para)
606
607
608
        next_page_paras[0].pop(
            0
        )  # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。
609
610
611
612
        return True
    else:
        return False

613

614
615
616
617
618
619
620
621
622
623
624
625
def find_consecutive_true_regions(input_array):
    start_index = None  # 连续True区域的起始索引
    regions = []  # 用于保存所有连续True区域的起始和结束索引

    for i in range(len(input_array)):
        # 如果我们找到了一个True值,并且当前并没有在连续True区域中
        if input_array[i] and start_index is None:
            start_index = i  # 记录连续True区域的起始索引

        # 如果我们找到了一个False值,并且当前在连续True区域中
        elif not input_array[i] and start_index is not None:
            # 如果连续True区域长度大于1,那么将其添加到结果列表中
626
627
            if i - start_index > 1:
                regions.append((start_index, i - 1))
628
629
630
631
            start_index = None  # 重置起始索引

    # 如果最后一个元素是True,那么需要将最后一个连续True区域加入到结果列表中
    if start_index is not None and len(input_array) - start_index > 1:
632
        regions.append((start_index, len(input_array) - 1))
633
634
635
636

    return regions


637
638
639
def __connect_middle_align_text(
    page_paras, new_layout_bbox, page_num, lang, debug_mode
):
640
641
642
643
644
645
    """
    找出来中间对齐的连续单行文本,如果连续行高度相同,那么合并为一个段落。
    一个line居中的条件是:
    1. 水平中心点跨越layout的中心点。
    2. 左右两侧都有空白
    """
646

647
648
649
650
    for layout_i, layout_para in enumerate(page_paras):
        layout_box = new_layout_bbox[layout_i]
        single_line_paras_tag = []
        for i in range(len(layout_para)):
651
652
653
654
655
            single_line_paras_tag.append(
                len(layout_para[i]) == 1
                and layout_para[i][0]['spans'][0]['type'] == TEXT
            )

656
        """找出来连续的单行文本,如果连续行高度相同,那么合并为一个段落。"""
657
658
659
660
        consecutive_single_line_indices = find_consecutive_true_regions(
            single_line_paras_tag
        )
        if len(consecutive_single_line_indices) > 0:
661
662
663
664
665
            index_offset = 0
            """检查这些行是否是高度相同的,居中的"""
            for start, end in consecutive_single_line_indices:
                start += index_offset
                end += index_offset
666
667
668
669
670
671
672
673
674
675
                line_hi = np.array(
                    [
                        line[0]['bbox'][3] - line[0]['bbox'][1]
                        for line in layout_para[start : end + 1]
                    ]
                )
                first_line_text = ''.join(
                    [__get_span_text(span) for span in layout_para[start][0]['spans']]
                )
                if 'Table' in first_line_text or 'Figure' in first_line_text:
676
                    pass
677
                if debug_mode:
kernel.h@qq.com's avatar
kernel.h@qq.com committed
678
                    logger.debug(line_hi.std())
679
680
681
682
683
684
685
686
687

                if line_hi.std() < 2:
                    """行高度相同,那么判断是否居中."""
                    all_left_x0 = [
                        line[0]['bbox'][0] for line in layout_para[start : end + 1]
                    ]
                    all_right_x1 = [
                        line[0]['bbox'][2] for line in layout_para[start : end + 1]
                    ]
688
                    layout_center = (layout_box[0] + layout_box[2]) / 2
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
                    if (
                        all(
                            [
                                x0 < layout_center < x1
                                for x0, x1 in zip(all_left_x0, all_right_x1)
                            ]
                        )
                        and not all([x0 == layout_box[0] for x0 in all_left_x0])
                        and not all([x1 == layout_box[2] for x1 in all_right_x1])
                    ):
                        merge_para = [l[0] for l in layout_para[start : end + 1]]  # noqa: E741
                        para_text = ''.join(
                            [
                                __get_span_text(span)
                                for line in merge_para
                                for span in line['spans']
                            ]
                        )
707
                        if debug_mode:
kernel.h@qq.com's avatar
kernel.h@qq.com committed
708
                            logger.debug(para_text)
709
710
711
                        layout_para[start : end + 1] = [merge_para]
                        index_offset -= end - start

712
    return
713

714
715

def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang):
716
717
    """找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。"""

718
719
    pass

720

kernel.h@qq.com's avatar
kernel.h@qq.com committed
721
def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
722
    """根据line和layout情况进行分段 先实现一个根据行末尾特征分段的简单方法。"""
xuchao's avatar
xuchao committed
723
724
725
726
727
728
729
    """
    算法思路:
    1. 扫描layout里每一行,找出来行尾距离layout有边界有一定距离的行。
    2. 从上述行中找到末尾是句号等可作为断行标志的行。
    3. 参照上述行尾特征进行分段。
    4. 图、表,目前独占一行,不考虑分段。
    """
730
    if page_num == 343:
kernel.h@qq.com's avatar
kernel.h@qq.com committed
731
        pass
732
733
734
735
736
737
738
739
740
741
742
    lines_group = __group_line_by_layout(blocks, layout_bboxes, lang)  # block内分段
    layout_paras, layout_list_info = __split_para_in_layoutbox(
        lines_group, new_layout_bbox, lang
    )  # layout内分段
    layout_paras2, page_list_info = __connect_list_inter_layout(
        layout_paras, new_layout_bbox, layout_list_info, page_num, lang
    )  # layout之间连接列表段落
    connected_layout_paras = __connect_para_inter_layoutbox(
        layout_paras2, new_layout_bbox, lang
    )  # layout间链接段落

kernel.h@qq.com's avatar
kernel.h@qq.com committed
743
    return connected_layout_paras, page_list_info
744
745
746
747
748
749


def para_split(pdf_info_dict, debug_mode, lang='en'):
    """根据line和layout情况进行分段."""
    new_layout_of_pages = []  # 数组的数组,每个元素是一个页面的layoutS
    all_page_list_info = []  # 保存每个页面开头和结尾是否是列表
kernel.h@qq.com's avatar
kernel.h@qq.com committed
750
    for page_num, page in pdf_info_dict.items():
751
752
753
754
        blocks = page['preproc_blocks']
        layout_bboxes = page['layout_bboxes']
        new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
        new_layout_of_pages.append(new_layout_bbox)
755
756
757
        splited_blocks, page_list_info = __do_split_page(
            blocks, layout_bboxes, new_layout_bbox, page_num, lang
        )
kernel.h@qq.com's avatar
kernel.h@qq.com committed
758
        all_page_list_info.append(page_list_info)
759
        page['para_blocks'] = splited_blocks
760

761
762
    """连接页面与页面之间的可能合并的段落"""
    pdf_infos = list(pdf_info_dict.values())
kernel.h@qq.com's avatar
kernel.h@qq.com committed
763
    for page_num, page in enumerate(pdf_info_dict.values()):
764
        if page_num == 0:
765
            continue
766
        pre_page_paras = pdf_infos[page_num - 1]['para_blocks']
kernel.h@qq.com's avatar
kernel.h@qq.com committed
767
        next_page_paras = pdf_infos[page_num]['para_blocks']
768
        pre_page_layout_bbox = new_layout_of_pages[page_num - 1]
kernel.h@qq.com's avatar
kernel.h@qq.com committed
769
        next_page_layout_bbox = new_layout_of_pages[page_num]
770
771
772
773
774
775
776
777
778

        is_conn = __connect_para_inter_page(
            pre_page_paras,
            next_page_paras,
            pre_page_layout_bbox,
            next_page_layout_bbox,
            page_num,
            lang,
        )
779
780
        if debug_mode:
            if is_conn:
781
782
783
784
785
786
787
788
789
790
791
792
                logger.info(f'连接了第{page_num-1}页和第{page_num}页的段落')

        is_list_conn = __connect_list_inter_page(
            pre_page_paras,
            next_page_paras,
            pre_page_layout_bbox,
            next_page_layout_bbox,
            all_page_list_info[page_num - 1],
            all_page_list_info[page_num],
            page_num,
            lang,
        )
793
794
        if debug_mode:
            if is_list_conn:
795
796
                logger.info(f'连接了第{page_num-1}页和第{page_num}页的列表段落')

797
798
799
800
801
802
803
    """接下来可能会漏掉一些特别的一些可以合并的内容,对他们进行段落连接
    1. 正文中有时出现一个行顶格,接下来几行缩进的情况。
    2. 居中的一些连续单行,如果高度相同,那么可能是一个段落。
    """
    for page_num, page in enumerate(pdf_info_dict.values()):
        page_paras = page['para_blocks']
        new_layout_bbox = new_layout_of_pages[page_num]
804
805
806
        __connect_middle_align_text(
            page_paras, new_layout_bbox, page_num, lang, debug_mode=debug_mode
        )
807
        __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang)