sentence_split.py 5.26 KB
Newer Older
1
2
3
import argparse
import functools
import json
mandoxzhang's avatar
mandoxzhang committed
4
5
6
7
import multiprocessing
import os
import re
import time
8
9
10
11
from typing import List

from tqdm import tqdm

mandoxzhang's avatar
mandoxzhang committed
12
13
14
15
16

def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]:
    """Split ``document`` into sentences, cutting over-long ones into chunks.

    A newline is inserted after every sentence-ending mark (or after a mark
    immediately followed by a closing quote, so the quote stays with its
    sentence); the text is then split into lines.

    Args:
        document: Text to split.
        flag: ``"zh"`` splits on Chinese marks (``。？！…`` plus ASCII ``?!``),
            ``"en"`` on ``.?!``; any other value (default ``"all"``) on both.
        limit: Maximum length, in characters, of a returned piece. Longer
            sentences are cut into consecutive ``limit``-character chunks
            (the last chunk may be shorter).

    Returns:
        The stripped sentences/chunks. Empty lines and lines of at most two
        characters are dropped. If splitting raises (e.g. ``document`` is not
        a ``str``), ``[document]`` is returned unchanged.

    Raises:
        ValueError: if ``limit`` is smaller than 1 (chunking would never
            make progress).
    """
    if limit < 1:
        raise ValueError(f"limit must be >= 1, got {limit}")

    repl = r"\g<quotation_mark>\n"
    sent_list = []
    try:
        text = document
        if flag == "zh":
            # Chinese text normally uses full-width ？！; ASCII ?! are kept for
            # mixed-width input.
            text = re.sub("(?P<quotation_mark>([。？！?!…](?![”’\"'])))", repl, text)
            text = re.sub("(?P<quotation_mark>([。？！?!]|…{1,2})[”’\"'])", repl, text)
        elif flag == "en":
            text = re.sub("(?P<quotation_mark>([.?!](?![”’\"'])))", repl, text)
            # Curly quotes must be accepted here too: the pattern above skips a
            # mark followed by ” or ’, so otherwise `.”` would never be split.
            text = re.sub("(?P<quotation_mark>([?!.][”’\"']))", repl, text)  # Special quotation marks
        else:
            text = re.sub("(?P<quotation_mark>([。？！….?!](?![”’\"'])))", repl, text)
            text = re.sub(
                "(?P<quotation_mark>(([。？！.!?]|…{1,2})[”’\"']))", repl, text
            )  # Special quotation marks

        for sent in text.splitlines():
            sent = sent.strip()
            if len(sent) <= 2:  # also covers empty lines
                continue
            while len(sent) > limit:
                sent_list.append(sent[:limit])
                sent = sent[limit:]
            sent_list.append(sent)
    except Exception:
        # Best effort: keep the unsplit document rather than dropping it.
        return [document]
    return sent_list


50
def get_sent(output_path, input_path, fin_list=None, host=-1, seq_len=512, workers=32) -> None:
    """Sentence-split the JSON files of one shard into ``<output_path>/<host>.txt``.

    Each input file is loaded with ``json.load`` and ``record["content"]`` is
    taken from every element (so each file is presumably a JSON array of
    objects with a ``"content"`` string — TODO confirm against the corpus).
    Every document is split with ``split_sentence`` using a chunk limit of
    ``seq_len - 2``; each sentence is written on its own line, and a ``]]``
    line is written after each document.

    Args:
        output_path: Directory receiving the output file (must exist).
        input_path: Directory containing the shard's input files.
        fin_list: ``[file_name, size]`` pairs as produced by ``getFileSize``;
            only ``file_name`` is used. Files that do not exist, or whose
            name does not contain ``".json"``, are skipped.
        host: Index used to name the output file.
        seq_len: Sequence length; sentences are chunked to ``seq_len - 2``.
        workers: Number of worker processes used for splitting.
    """
    if fin_list is None:
        fin_list = []
    if input_path.endswith("/"):
        input_path = input_path[:-1]

    cur_path = os.path.join(output_path, str(host) + ".txt")
    new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2)
    with open(cur_path, "w", encoding="utf-8") as f:
        for fi, fin_path in enumerate(fin_list):
            file_path = os.path.join(input_path, fin_path[0])
            if not os.path.exists(file_path):
                continue
            if ".json" not in fin_path[0]:
                continue

            print("Processing ", fin_path[0], " ", fi)

            # Explicit UTF-8: the default text encoding is locale dependent and
            # the corpus may contain non-ASCII text.
            with open(file_path, "r", encoding="utf-8") as fin:
                f_data = [record["content"] for record in json.load(fin)]

            # Results are consumed inside the ``with`` so the pool is only
            # terminated after all work is done and no workers leak per file.
            # Note: imap_unordered means document order in the output is not
            # guaranteed to match the input order.
            with multiprocessing.Pool(workers) as pool:
                all_sent = pool.imap_unordered(new_split_sentence, f_data, 32)
                for sentences in tqdm(all_sent, total=len(f_data)):
                    for sent in sentences:
                        f.write(sent.strip() + "\n")
                    f.write("]]" + "\n")
            print("finished..")


def getFileSize(filepath, shard):
    """Group the regular files of ``filepath`` into roughly size-balanced shards.

    Files are sorted by size (largest first) and appended greedily to the
    current shard; a shard is closed as soon as its accumulated size exceeds
    ``total_size / shard``. Leftover files form a final shard, so the number
    of shards returned may differ from ``shard``.

    Args:
        filepath: Directory whose regular files are sharded (not recursive).
        shard: Target number of shards; must be positive.

    Returns:
        A list of shards, each a list of ``[file_name, size_in_bytes]`` pairs.

    Raises:
        ValueError: if ``shard`` is not positive.
    """
    if shard <= 0:
        raise ValueError(f"shard must be positive, got {shard}")

    # Join each listdir() name with ``filepath`` exactly once; joining an
    # already-joined path again breaks for relative directories.
    entries = []
    for name in os.listdir(filepath):
        full_path = os.path.join(filepath, name)
        if os.path.isfile(full_path):
            entries.append([name, os.path.getsize(full_path)])
    all_size = sum(size for _, size in entries)
    entries.sort(key=lambda entry: entry[1], reverse=True)

    per_size = all_size / shard
    real_shard = []
    temp = []
    accu_size = 0
    for entry in entries:
        accu_size += entry[1]
        temp.append(entry)
        if accu_size > per_size:
            real_shard.append(temp)
            accu_size = 0
            temp = []

    if temp:
        real_shard.append(temp)

    return real_shard


110
def get_start_end(real_shard, base=0, server_num=10, server_name="GPU", hostname=None):
    """Select the shard this machine should process in round ``base``.

    The host number is the integer after the last ``server_name`` in the
    hostname (e.g. ``"GPU3"`` -> 3). Because the offset is ``host - 1``, host
    numbering is presumably 1-based. The shard index is
    ``server_num * base + host - 1``.

    Args:
        real_shard: Shards as returned by ``getFileSize``.
        base: Round number; each round covers ``server_num`` consecutive shards.
        server_num: Number of servers sharing the work.
        server_name: Hostname text that precedes the host number.
        hostname: Hostname to parse; defaults to ``socket.gethostname()``.

    Returns:
        ``(fin_list, host)``: the selected shard and the parsed host number.

    Raises:
        ValueError: if no integer host number can be parsed from the hostname.
        IndexError: if the computed shard index is outside ``real_shard``.
    """
    if hostname is None:
        import socket

        hostname = socket.gethostname()

    host = int(hostname.split(server_name)[-1])
    index = server_num * base + host - 1
    # Reject out-of-range indices explicitly: a negative index would otherwise
    # wrap around and silently hand this host another server's shard.
    if not 0 <= index < len(real_shard):
        raise IndexError(
            f"shard index {index} out of range for {len(real_shard)} shards (host {host}, base {base})"
        )

    fin_list = real_shard[index]
    print(fin_list)
    print(f"I am server {host}, process {index}, len {len(fin_list)}")
    return fin_list, host


121
if __name__ == "__main__":
    # Command-line entry point: shard the input directory by file size and
    # sentence-split every shard into <output_path>/<index>.txt.
    parser = argparse.ArgumentParser()
    parser.add_argument("--server_num", type=int, default=10, help="number of servers")
    parser.add_argument("--seq_len", type=int, default=512, help="sequence length")
    parser.add_argument("--shard", type=int, default=100, help="number of shards, e.g., 10, 50, or 100")
    parser.add_argument("--input_path", type=str, required=True, help="input path of original corpus")
    parser.add_argument("--output_path", type=str, required=True, help="output path of shard which has split sentence")
    args = parser.parse_args()

    server_num = args.server_num
    seq_len = args.seq_len
    shard = args.shard
    input_path = args.input_path
    output_path = args.output_path

    # get_sent opens <output_path>/<index>.txt for writing, which fails if the
    # directory does not exist yet.
    os.makedirs(output_path, exist_ok=True)

    real_shard = getFileSize(input_path, shard)

    start = time.time()
    # ``shard_files`` rather than ``shard`` so the CLI shard count is not clobbered.
    for index, shard_files in enumerate(real_shard):
        get_sent(output_path, input_path, fin_list=shard_files, host=index, seq_len=seq_len)
    print(f"cost {str(time.time() - start)}")

    # With multiple servers, each server can process only its own shards using
    # the code below (or the loop can be ported to OpenMPI):
    # for i in range(len(real_shard) // server_num + 1):
    #     fin_list, host = get_start_end(real_shard, i, server_num=server_num)
    #     start = time.time()
    #     get_sent(output_path, input_path, fin_list=fin_list,
    #              host=server_num * i + host - 1, seq_len=seq_len)
    #     print(f'cost {str(time.time() - start)}')