#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Data pre-processing: build vocabularies and binarize training data.
"""

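# Example invocation (illustrative paths; every flag is defined in get_parser
# below):
#
#   python preprocess.py --source-lang de --target-lang en \
#       --trainpref data/train --validpref data/valid --testpref data/test \
#       --destdir data-bin/de-en --thresholdsrc 5 --thresholdtgt 5 --workers 4
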
import argparse
from collections import Counter
from itertools import zip_longest
import os
import shutil

from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s", "--source-lang", default=None, metavar="SRC", help="source language"
    )
    parser.add_argument(
        "-t", "--target-lang", default=None, metavar="TARGET", help="target language"
    )
    parser.add_argument(
        "--trainpref", metavar="FP", default=None, help="train file prefix"
    )
    parser.add_argument(
        "--validpref",
        metavar="FP",
        default=None,
        help="comma-separated valid file prefixes",
    )
    parser.add_argument(
        "--testpref",
        metavar="FP",
        default=None,
        help="comma-separated test file prefixes",
    )
    parser.add_argument(
        "--destdir", metavar="DIR", default="data-bin", help="destination dir"
    )
    parser.add_argument(
        "--thresholdtgt",
        metavar="N",
        default=0,
        type=int,
        help="map words appearing fewer than threshold times to unknown",
    )
    parser.add_argument(
        "--thresholdsrc",
        metavar="N",
        default=0,
        type=int,
        help="map words appearing fewer than threshold times to unknown",
    )
    parser.add_argument("--tgtdict", metavar="FP", help="reuse given target dictionary")
    parser.add_argument("--srcdict", metavar="FP", help="reuse given source dictionary")
    parser.add_argument(
        "--nwordstgt",
        metavar="N",
        default=-1,
        type=int,
        help="number of target words to retain",
    )
    parser.add_argument(
        "--nwordssrc",
        metavar="N",
        default=-1,
        type=int,
        help="number of source words to retain",
    )
    parser.add_argument(
        "--alignfile",
        metavar="ALIGN",
        default=None,
        help="an alignment file (optional)",
    )
    parser.add_argument(
        "--output-format",
        metavar="FORMAT",
        default="binary",
        choices=["binary", "raw"],
        help="output format (optional)",
    )
    parser.add_argument(
        "--joined-dictionary", action="store_true", help="generate joined dictionary"
    )
    parser.add_argument(
        "--only-source", action="store_true", help="only process the source language"
    )
    parser.add_argument(
        "--padding-factor",
        metavar="N",
        default=8,
        type=int,
        help="pad dictionary size to be a multiple of N",
    )
    parser.add_argument(
        "--workers", metavar="N", default=1, type=int, help="number of parallel workers"
    )
    return parser


def main(args):
    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

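    # Build (or load) the vocabularies: --joined-dictionary shares a single
    # dictionary across source and target; otherwise each side is loaded from
    # --srcdict/--tgtdict or built from the training data.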
    if args.joined_dictionary:
        assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
        assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
        src_dict = build_dictionary(
            {train_path(lang) for lang in [args.source_lang, args.target_lang]},
            args.workers,
        )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = dictionary.Dictionary.load(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], args.workers)
        if target:
            if args.tgtdict:
                tgt_dict = dictionary.Dictionary.load(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary(
                    [train_path(args.target_lang)], args.workers
                )

    src_dict.finalize(
        threshold=args.thresholdsrc,
        nwords=args.nwordssrc,
        padding_factor=args.padding_factor,
    )
    src_dict.save(dict_path(args.source_lang))
    if target:
        if not args.joined_dictionary:
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=args.padding_factor,
            )
        tgt_dict.save(dict_path(args.target_lang))

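    # Binarize one input file: chunk 0 is handled in the current process while
    # chunks 1..num_workers-1 go to a worker pool whose temporary outputs are
    # merged into the final .bin/.idx pair below.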
    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        vocab = dictionary.Dictionary.load(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
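        # Byte offsets that split the input file into one chunk per worker.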
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

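        # Binarize the first chunk in the current process and merge its stats;
        # the builder below also becomes the merge target for worker shards.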
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )
        merge_result(
            Tokenizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            )
        )

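    # Dispatch on --output-format: "binary" builds an indexed dataset, "raw"
    # just copies the original text file into destdir.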
    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

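    # Process every split for one language; extra comma-separated valid/test
    # prefixes become valid1, valid2, ... and test1, test2, ...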
    def make_all(lang):
        if args.trainpref:
            make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print("| Wrote preprocessed data to {}".format(args.destdir))

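    # Optionally build a word-alignment dictionary: for each source word, keep
    # the target word it is most frequently aligned to in --alignfile.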
    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
        freq_map = {}
        with open(args.alignfile, "r") as align_file:
            with open(src_file_name, "r") as src_file:
                with open(tgt_file_name, "r") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {
            srcidx: max(freq_map[srcidx], key=freq_map[srcidx].get)
            for srcidx in freq_map
        }

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


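# Standalone helpers: the functions below mirror steps of main() and appear
# intended for programmatic use; build_and_save_dictionary, binarize_with_load,
# get_offsets, and merge_files are not called above.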
def build_and_save_dictionary(
    train_path, output_path, num_workers, freq_threshold, max_words
):
    d = build_dictionary([train_path], num_workers)
    d.finalize(threshold=freq_threshold, nwords=max_words)
    dict_path = os.path.join(output_path, "dict.txt")
    d.save(dict_path)
    return dict_path


def build_dictionary(filenames, workers):
    d = dictionary.Dictionary()
    for filename in filenames:
        Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, workers)
    return d


def binarize(args, filename, vocab, output_prefix, lang, offset, end):
    ds = indexed_dataset.IndexedDatasetBuilder(
        dataset_dest_file(args, output_prefix, lang, "bin")
    )

    def consumer(tensor):
        ds.add_item(tensor)

    res = Tokenizer.binarize(filename, vocab, consumer, offset=offset, end=end)
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res


def binarize_with_load(args, filename, dict_path, output_prefix, lang, offset, end):
    vocab = dictionary.Dictionary.load(dict_path)
    binarize(args, filename, vocab, output_prefix, lang, offset, end)
    return dataset_dest_prefix(args, output_prefix, lang)

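# A minimal sketch (illustrative names) of how binarize_with_load composes with
# get_offsets and merge_files (defined below) for sharded binarization:
#
#   offsets = get_offsets(input_file, num_workers)
#   prefixes = [
#       binarize_with_load(args, input_file, dict_file, "shard{}".format(i),
#                          lang, offsets[i], offsets[i + 1])
#       for i in range(num_workers)
#   ]
#   merge_files(prefixes, outpath)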

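# Output artifacts are named "<destdir>/<prefix>.<src>-<tgt>.<lang>.{bin,idx}";
# the ".<src>-<tgt>.<lang>" part is omitted when lang is None.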
def dataset_dest_prefix(args, output_prefix, lang):
    base = f"{args.destdir}/{output_prefix}"
    lang_part = (
        f".{args.source_lang}-{args.target_lang}.{lang}" if lang is not None else ""
    )
    return f"{base}{lang_part}"


def dataset_dest_file(args, output_prefix, lang, extension):
    base = dataset_dest_prefix(args, output_prefix, lang)
    return f"{base}.{extension}"


def get_offsets(input_file, num_workers):
    return Tokenizer.find_offsets(input_file, num_workers)


def merge_files(files, outpath):
    ds = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath))
    for fpath in files:
        ds.merge_file_(fpath)
        os.remove(indexed_dataset.data_file_path(fpath))
        os.remove(indexed_dataset.index_file_path(fpath))
    ds.finalize("{}.idx".format(outpath))


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)