# tokenize_mask.py
import argparse
import multiprocessing
import os
import time
from random import shuffle

import h5py
import numpy as np
import psutil
from get_mask import PreTrainingDataset
from tqdm import tqdm
from transformers import AutoTokenizer


def get_raw_instance(document, max_sequence_length=512):
    """
    Pack the sentences of one document into length-bounded training instances.

    Sentences are concatenated greedily, in order, while the running length
    stays within ``max_sequence_length - 2`` (two positions are held back,
    presumably for special tokens such as [CLS]/[SEP] -- TODO confirm against
    the consumer of these instances).

    * A sentence that does not fit closes the current instance and starts a
      new one.
    * A sentence that is by itself at least as long as the budget closes the
      current instance (if non-empty) and is emitted alone, truncated to the
      budget.
    * The trailing instance is kept only if it is longer than half of the
      budget; shorter leftovers are dropped.

    :param document: list of sentences; each sentence is a sequence of tokens
    :param max_sequence_length: maximum instance length including the two
        reserved positions
    :return: a list. each element is a sequence of tokens
    """
    max_sequence_length_allowed = max_sequence_length - 2

    result_list = []
    curr_seq = []
    for sentence in document:
        if len(curr_seq) + len(sentence) <= max_sequence_length_allowed:
            # Still fits: keep growing the current instance.
            curr_seq += sentence
        elif len(sentence) >= max_sequence_length_allowed:
            # Over-long sentence: flush what we have, then emit it truncated.
            if len(curr_seq) > 0:
                result_list.append(curr_seq)
            curr_seq = []
            result_list.append(sentence[:max_sequence_length_allowed])
        else:
            # Does not fit next to the current content but fits on its own:
            # close the current instance and start a new one with it.
            result_list.append(curr_seq)
            curr_seq = []
            curr_seq += sentence

    # Keep the tail only when it fills more than half of the budget.
    if len(curr_seq) > max_sequence_length_allowed / 2:
        result_list.append(curr_seq)

    return result_list


def split_numpy_chunk(path, tokenizer, pretrain_data, host, seq_len=512, output_dir="/output"):
    """
    Convert one text shard into an HDF5 file of masked training instances
    (single-process variant).

    The shard at ``path`` is read line by line; a line starting with ``]]``
    closes the current document, and any other line of at least two
    characters is added to it. Each document is tokenized with
    ``pretrain_data.tokenize``, packed with ``get_raw_instance`` and turned
    into arrays by ``pretrain_data.create_training_instance``, whose result
    is indexed 0..3 and stored as ``input_ids``, ``input_mask``,
    ``segment_ids`` and ``masked_lm_positions`` respectively.

    :param path: path of the UTF-8 text shard to read
    :param tokenizer: unused; kept so existing callers keep working
    :param pretrain_data: object providing ``tokenize`` and
        ``create_training_instance``
    :param host: used as the stem of the output file name (``{host}.h5``)
    :param seq_len: row width of the output arrays and packing limit
        (default 512, the value that was previously hard-coded)
    :param output_dir: directory the ``.h5`` file is written to (default
        ``/output``, the location that was previously hard-coded)
    """
    s = time.time()
    documents = []
    with open(path, encoding="utf-8") as fd:
        document = []
        for line in tqdm(fd):
            line = line.strip()
            if line.startswith("]]"):  # This is end of document
                documents.append(document)
                document = []
            elif len(line) >= 2:
                document.append(line)
        # The last document may not be terminated by a "]]" line.
        if len(document) > 0:
            documents.append(document)
    print("read_file ", time.time() - s)

    ans = [pretrain_data.tokenize(docs) for docs in tqdm(documents)]
    print(time.time() - s)
    del documents  # release the raw text before building the instances

    instances = []
    for a in tqdm(ans):
        instances.extend(get_raw_instance(a, max_sequence_length=seq_len))
    del ans

    print("len instance", len(instances))

    sen_num = len(instances)
    input_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
    input_mask = np.zeros([sen_num, seq_len], dtype=np.int32)
    segment_ids = np.zeros([sen_num, seq_len], dtype=np.int32)
    masked_lm_output = np.zeros([sen_num, seq_len], dtype=np.int32)

    for index, ins in enumerate(tqdm(instances)):
        mask_dict = pretrain_data.create_training_instance(ins)
        input_ids[index] = mask_dict[0]
        input_mask[index] = mask_dict[1]
        segment_ids[index] = mask_dict[2]
        masked_lm_output[index] = mask_dict[3]

    with h5py.File(os.path.join(output_dir, f"{host}.h5"), "w") as hf:
        hf.create_dataset("input_ids", data=input_ids)
        # Fixed: this dataset used to be written from input_ids by mistake.
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)

    del instances


def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name):
    """
    Convert one text shard into an HDF5 file of masked training instances,
    using a process pool for tokenization and instance creation.

    If ``{output_path}/{file_name}.h5`` already exists the shard is skipped.
    Otherwise the shard is read line by line (a line starting with ``]]``
    closes the current document; other lines of at least two characters are
    added to it), every document is tokenized with ``pretrain_data.tokenize``,
    packed with ``get_raw_instance``, repeated ``dupe_factor`` times, shuffled
    and passed to ``pretrain_data.create_training_instance``. The result of
    that call is indexed 0..3 and stored as ``input_ids``, ``input_mask``,
    ``segment_ids`` and ``masked_lm_positions`` respectively.

    Results are collected with ``imap_unordered``, so rows are written in
    completion order; the four arrays of one row always come from the same
    call.

    :param input_path: path of the UTF-8 text shard to read
    :param output_path: directory that receives ``{file_name}.h5``
    :param pretrain_data: object providing ``tokenize`` and
        ``create_training_instance``; both are dispatched to worker processes
    :param worker: number of processes in each pool
    :param dupe_factor: how many times every packed instance is repeated
    :param seq_len: packing limit and row width of the output arrays
    :param file_name: stem of the output file name
    """
    target_file = os.path.join(output_path, f"{file_name}.h5")
    if os.path.exists(target_file):
        print(f"{file_name}.h5 exists")
        return

    s = time.time()
    documents = []
    with open(input_path, "r", encoding="utf-8") as fd:
        document = []
        for line in tqdm(fd):
            line = line.strip()
            if line.startswith("]]"):  # This is end of document
                documents.append(document)
                document = []
            elif len(line) >= 2:
                document.append(line)
        # The last document may not be terminated by a "]]" line.
        if len(document) > 0:
            documents.append(document)
    print(f"read_file cost {time.time() - s}, length is {len(documents)}")

    s = time.time()
    # The context manager shuts the workers down even if tokenization raises;
    # the lazy iterator is fully drained before the block exits.
    with multiprocessing.Pool(worker) as pool:
        encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100)
        ans = list(tqdm(encoded_doc, total=len(documents), colour="cyan"))
    print((time.time() - s) / 60)
    del documents  # release the raw text before building the instances

    instances = []
    for a in tqdm(ans, colour="MAGENTA"):
        instances.extend(get_raw_instance(a, max_sequence_length=seq_len))
    del ans

    print("len instance", len(instances))

    # Repeat every instance dupe_factor times, then shuffle the result.
    instances = instances * dupe_factor
    shuffle(instances)
    print("after dupe_factor, len instance", len(instances))

    sentence_num = len(instances)
    input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
    input_mask = np.zeros([sentence_num, seq_len], dtype=np.int32)
    segment_ids = np.zeros([sentence_num, seq_len], dtype=np.int32)
    masked_lm_output = np.zeros([sentence_num, seq_len], dtype=np.int32)

    s = time.time()
    with multiprocessing.Pool(worker) as pool:
        encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32)
        for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour="blue"):
            input_ids[index] = mask_dict[0]
            input_mask[index] = mask_dict[1]
            segment_ids[index] = mask_dict[2]
            masked_lm_output[index] = mask_dict[3]
    print((time.time() - s) / 60)

    # Create the target directory up front so a missing folder does not throw
    # away the work done above.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    with h5py.File(target_file, "w") as hf:
        hf.create_dataset("input_ids", data=input_ids)
        hf.create_dataset("input_mask", data=input_mask)
        hf.create_dataset("segment_ids", data=segment_ids)
        hf.create_dataset("masked_lm_positions", data=masked_lm_output)

    del instances


if __name__ == "__main__":
    # Command-line entry point: builds the tokenizer and the masking dataset,
    # then converts every "<i>.txt" shard found in --input_path into
    # "<i>.h5" under --output_path.
    parser = argparse.ArgumentParser()
    # A required option never falls back to a default, so the former
    # (mistyped) default=10 has been removed.
    parser.add_argument("--tokenizer_path", type=str, required=True, help="path of tokenizer")
    parser.add_argument("--seq_len", type=int, default=512, help="sequence length")
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=80,
        help="maximum number of masked LM predictions per sequence",
    )
    parser.add_argument("--input_path", type=str, required=True, help="input path of shard which has split sentence")
    parser.add_argument("--output_path", type=str, required=True, help="output path of h5 contains token id")
    parser.add_argument(
        "--backend", type=str, default="python", help="backend of mask token, python, c++, numpy respectively"
    )
    parser.add_argument(
        "--dupe_factor",
        type=int,
        default=1,
        help="specifies how many times the preprocessor repeats to create the input from the same article/document",
    )
    parser.add_argument("--worker", type=int, default=32, help="number of process")
    # Only referenced by the commented-out multi-server variant below.
    parser.add_argument("--server_num", type=int, default=10, help="number of servers")
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    pretrain_data = PreTrainingDataset(
        tokenizer, args.seq_len, args.backend, max_predictions_per_seq=args.max_predictions_per_seq
    )

    # Shards are expected to be named 0.txt, 1.txt, ...; the number of
    # directory entries bounds the indices that are tried.
    data_len = len(os.listdir(args.input_path))

    for i in range(data_len):
        input_path = os.path.join(args.input_path, f"{i}.txt")
        if os.path.exists(input_path):
            start = time.time()
            print(f"process {input_path}")
            split_numpy_chunk_pool(
                input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, args.seq_len, i
            )
            end_ = time.time()
            print("memory:%.4f GB" % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024))
            print(f"has cost {(end_ - start) / 60}")
            print("-" * 100)
            print("")

    # if you have multiple server, you can use code below or modify code to openmpi
    # (the variant below additionally needs `import socket`)

    # host = int(socket.gethostname().split('GPU')[-1])
    # for i in range(data_len // args.server_num + 1):
    #     h = args.server_num * i + host - 1
    #     input_path = os.path.join(args.input_path, f'{h}.txt')
    #     if os.path.exists(input_path):
    #         start = time.time()
    #         print(f'I am server {host}, process {input_path}')
    #         split_numpy_chunk_pool(input_path,
    #                                 args.output_path,
    #                                 pretrain_data,
    #                                 args.worker,
    #                                 args.dupe_factor,
    #                                 args.seq_len,
    #                                 h)
    #         end_ = time.time()
    #         print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
    #         print(f'has cost {(end_ - start) / 60}')
    #         print('-' * 100)
    #         print('')