# sentence_split.py
import argparse
import functools
import json
import multiprocessing
import os
import re
import time
from typing import List

from tqdm import tqdm


# Newline-insertion patterns per language flag. In each pair the first pattern
# matches a sentence terminator that is NOT followed by a closing quote, the
# second a terminator together with the closing quote that follows it.
# NOTE(review): the "en" quote pattern omits the curly quotes that its
# companion lookahead skips, and the default class lists `?!` twice
# (presumably one pair was meant to be the full-width forms) -- confirm
# against the upstream file before changing the patterns.
_SPLIT_PATTERNS = {
    "zh": (
        re.compile('(?P<quotation_mark>([。?!…](?![”’"\'])))'),
        re.compile('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])'),
    ),
    "en": (
        re.compile('(?P<quotation_mark>([.?!](?![”’"\'])))'),
        re.compile('(?P<quotation_mark>([?!.]["\']))'),
    ),
}
_DEFAULT_SPLIT_PATTERNS = (
    re.compile('(?P<quotation_mark>([。?!….?!](?![”’"\'])))'),
    re.compile('(?P<quotation_mark>(([。?!.!?]|…{1,2})[”’"\']))'),
)


def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
    """Split ``document`` into sentences of at most ``limit`` characters.

    A newline is inserted after every sentence-ending punctuation mark (and
    after a closing quote that directly follows one); the text is then split
    on line boundaries. Fragments of two characters or fewer are dropped and
    fragments longer than ``limit`` are cut into ``limit``-sized pieces.

    Args:
        document: Text to split.
        flag: ``"zh"`` selects the CJK punctuation patterns, ``"en"`` the
            ASCII punctuation patterns; any other value applies the combined
            patterns.
        limit: Maximum number of characters per returned sentence; must be
            positive.

    Returns:
        The stripped sentences in document order. If splitting fails (for
        example because ``document`` is not a string), a one-element list
        holding ``document``.

    Raises:
        ValueError: If ``limit`` is not positive.
    """
    if limit <= 0:
        # A non-positive limit would make the chunking below never terminate.
        raise ValueError(f"limit must be positive, got {limit}")

    sent_list: List[str] = []
    try:
        for pattern in _SPLIT_PATTERNS.get(flag, _DEFAULT_SPLIT_PATTERNS):
            document = pattern.sub(r'\g<quotation_mark>\n', document)

        for sent in document.splitlines():
            sent = sent.strip()
            if len(sent) <= 2:
                # Empty lines and one/two-character leftovers are discarded.
                continue
            for start in range(0, len(sent), limit):
                sent_list.append(sent[start:start + limit])
    except Exception:
        # Best effort: hand the input back as a single "sentence". Only
        # ``Exception`` is caught so that KeyboardInterrupt/SystemExit still
        # propagate out of worker processes.
        sent_list.clear()
        sent_list.append(document)
    return sent_list


def get_sent(output_path, input_path, fin_list=None, host=-1, seq_len=512, workers=32) -> None:
    """Sentence-split a group of JSON files into one ``<host>.txt`` shard file.

    Each input file is loaded as JSON and the ``'content'`` value of every
    element is split with ``split_sentence`` (limit ``seq_len - 2``) in a
    process pool. Every sentence is written on its own line and a ``]]`` line
    is written after each document. Documents are written in completion
    order, not input order.

    Args:
        output_path: Directory that receives ``<host>.txt``; created if it
            does not exist.
        input_path: Directory containing the input files.
        fin_list: Entries whose first element is a file name inside
            ``input_path``. Entries that do not exist or whose name does not
            contain ``.json`` are skipped. ``None`` means no files.
        host: Identifier used as the output file's base name.
        seq_len: Target sequence length; sentences are capped at
            ``seq_len - 2`` characters.
        workers: Number of worker processes used per input file.
    """
    if fin_list is None:
        fin_list = []

    if input_path.endswith('/'):
        input_path = input_path[:-1]

    if output_path:
        os.makedirs(output_path, exist_ok=True)
    cur_path = os.path.join(output_path, str(host) + '.txt')
    new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)

    with open(cur_path, 'w', encoding='utf-8') as f:
        for fi, fin_path in enumerate(fin_list):
            src_path = os.path.join(input_path, fin_path[0])
            if not os.path.exists(src_path):
                continue
            if '.json' not in fin_path[0]:
                continue

            print("Processing ", fin_path[0], " ", fi)

            with open(src_path, 'r', encoding='utf-8') as fin:
                f_data = [l['content'] for l in json.load(fin)]

            # Leaving the ``with`` block terminates the pool, so every result
            # must be consumed inside it.
            with multiprocessing.Pool(workers) as pool:
                all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
                for d in tqdm(all_sent, total=len(f_data)):
                    for i in d:
                        f.write(i.strip() + '\n')
                    # ']]' marks the end of one document.
                    f.write(']]' + '\n')
            print('finished..')


def getFileSize(filepath, shard):
    """Group the entries of ``filepath`` into shards of roughly equal byte size.

    Entries are ordered by size, largest first, and appended to the current
    group until its accumulated size exceeds ``total_size / shard``; a new
    group is then started. The number of groups returned is therefore not
    guaranteed to equal ``shard``.

    Args:
        filepath: Directory whose entries are grouped (absolute or relative).
        shard: Desired number of shards; must be positive.

    Returns:
        A list of groups, each a list of ``[name, size_in_bytes]`` pairs where
        ``name`` is the entry name relative to ``filepath``.

    Raises:
        ValueError: If ``shard`` is not positive.
    """
    if shard <= 0:
        raise ValueError(f"shard must be positive, got {shard}")

    # Join with ``filepath`` exactly once so that relative directories work.
    ans = [[name, os.path.getsize(os.path.join(filepath, name))] for name in os.listdir(filepath)]
    ans.sort(key=lambda entry: entry[1], reverse=True)
    per_size = sum(size for _, size in ans) / shard

    real_shard = []
    temp = []
    accu_size = 0
    for entry in ans:
        accu_size += entry[1]
        temp.append(entry)
        if accu_size > per_size:
            real_shard.append(temp)
            accu_size = 0
            temp = []

    # Whatever is left over forms the final, possibly smaller, group.
    if temp:
        real_shard.append(temp)

    return real_shard


def get_start_end(real_shard, base=0, server_num=10, server_name='GPU', hostname=None):
    """Pick the shard group assigned to this server.

    The server number is the integer that follows the last occurrence of
    ``server_name`` in the host name (e.g. ``GPU3`` -> ``3``). The selected
    group is ``real_shard[server_num * base + host - 1]``.

    Args:
        real_shard: List of shard groups to choose from.
        base: Round number; each round advances the index by ``server_num``.
        server_num: Number of servers taking part.
        server_name: Host-name prefix that precedes the server number.
        hostname: Host name to parse; defaults to ``socket.gethostname()``.

    Returns:
        A ``(fin_list, host)`` tuple: the selected group and the parsed
        server number.

    Raises:
        ValueError: If no integer follows ``server_name`` in the host name.
        IndexError: If the computed index is outside ``real_shard``. A
            negative index is rejected too, rather than silently wrapping
            around to the end of the list.
    """
    import socket

    if hostname is None:
        hostname = socket.gethostname()
    suffix = hostname.split(server_name)[-1]
    try:
        host = int(suffix)
    except ValueError:
        raise ValueError(
            f'cannot parse a server number from hostname {hostname!r} using prefix {server_name!r}') from None

    index = server_num * base + host - 1
    if not 0 <= index < len(real_shard):
        raise IndexError(f'shard index {index} is out of range for {len(real_shard)} shard groups')

    fin_list = real_shard[index]
    print(fin_list)
    print(f'I am server {host}, process {index}, len {len(fin_list)}')
    return fin_list, host


if __name__ == '__main__':

    # Command-line interface: corpus location, output location and sharding options.
    cli = argparse.ArgumentParser()
    cli.add_argument('--server_num', type=int, default=10, help='number of servers')
    cli.add_argument('--seq_len', type=int, default=512, help='sequence length')
    cli.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100')
    cli.add_argument('--input_path', type=str, required=True, help='input path of original corpus')
    cli.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence')
    args = cli.parse_args()

    server_num, seq_len, shard = args.server_num, args.seq_len, args.shard
    input_path, output_path = args.input_path, args.output_path

    # Group the corpus files into shards of roughly equal size.
    real_shard = getFileSize(input_path, shard)

    # Single-machine mode: process every shard group in turn, writing one
    # output file per group, and report the total elapsed time.
    start = time.time()
    for index, shard_files in enumerate(real_shard):
        get_sent(output_path, input_path, fin_list=shard_files, host=index, seq_len=seq_len)
    print(f'cost {str(time.time() - start)}')

    # If you have multiple servers, you can use the code below instead, or
    # adapt it to OpenMPI.

    # for i in range(len(real_shard) // server_num + 1):
    #     fin_list, host = get_start_end(real_shard, i)

    #     start = time.time()
    #     get_sent(output_path,
    #             input_path,
    #             fin_list=fin_list, host= 10 * i + host - 1)

    #     print(f'cost {str(time.time() - start)}')