preprocess.py 13 KB
Newer Older
Louis Martin's avatar
Louis Martin committed
1
#!/usr/bin/env python3
Sergey Edunov's avatar
Sergey Edunov committed
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8
9
10
"""
Data pre-processing: build vocabularies and binarize training data.
"""
Sergey Edunov's avatar
Sergey Edunov committed
11
12

import argparse
Sergey Edunov's avatar
Sergey Edunov committed
13
from collections import Counter
Sergey Edunov's avatar
Sergey Edunov committed
14
from itertools import zip_longest
15
16
import os
import shutil
Sergey Edunov's avatar
Sergey Edunov committed
17

Sergey Edunov's avatar
Sergey Edunov committed
18

alexeib's avatar
alexeib committed
19
from fairseq.data import indexed_dataset, dictionary
Myle Ott's avatar
Myle Ott committed
20
from fairseq.tokenizer import Tokenizer, tokenize_line
Myle Ott's avatar
Myle Ott committed
21
from multiprocessing import Pool
Sergey Edunov's avatar
Sergey Edunov committed
22

Sergey Edunov's avatar
Sergey Edunov committed
23

Myle Ott's avatar
Myle Ott committed
24
def get_parser():
    """Build the command-line argument parser for data preprocessing.

    Returns:
        argparse.ArgumentParser: parser covering language pair selection,
        train/valid/test file prefixes, dictionary options, output format,
        and parallelism settings.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # fmt: off
    # Language pair
    add("-s", "--source-lang", default=None, metavar="SRC",
        help="source language")
    add("-t", "--target-lang", default=None, metavar="TARGET",
        help="target language")
    # Input file prefixes
    add("--trainpref", metavar="FP", default=None,
        help="train file prefix")
    add("--validpref", metavar="FP", default=None,
        help="comma separated, valid file prefixes")
    add("--testpref", metavar="FP", default=None,
        help="comma separated, test file prefixes")
    add("--destdir", metavar="DIR", default="data-bin",
        help="destination dir")
    # Dictionary construction options
    add("--thresholdtgt", metavar="N", default=0, type=int,
        help="map words appearing less than threshold times to unknown")
    add("--thresholdsrc", metavar="N", default=0, type=int,
        help="map words appearing less than threshold times to unknown")
    add("--tgtdict", metavar="FP",
        help="reuse given target dictionary")
    add("--srcdict", metavar="FP",
        help="reuse given source dictionary")
    add("--nwordstgt", metavar="N", default=-1, type=int,
        help="number of target words to retain")
    add("--nwordssrc", metavar="N", default=-1, type=int,
        help="number of source words to retain")
    # Misc
    add("--alignfile", metavar="ALIGN", default=None,
        help="an alignment file (optional)")
    add("--output-format", metavar="FORMAT", default="binary",
        choices=["binary", "raw"],
        help="output format (optional)")
    add("--joined-dictionary", action="store_true",
        help="Generate joined dictionary")
    add("--only-source", action="store_true",
        help="Only process the source language")
    add("--padding-factor", metavar="N", default=8, type=int,
        help="Pad dictionary size to be multiple of N")
    add("--workers", metavar="N", default=1, type=int,
        help="number of parallel workers")
    # fmt: on
    return parser
Sergey Edunov's avatar
Sergey Edunov committed
66

Myle Ott's avatar
Myle Ott committed
67

Myle Ott's avatar
Myle Ott committed
68
def main(args):
    """Run preprocessing: build/load dictionaries, binarize (or copy) the
    train/valid/test corpora into ``args.destdir``, and optionally produce
    a word-alignment dictionary from ``args.alignfile``.
    """
    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    # Target side is processed unless --only-source was given.
    target = not args.only_source

    def train_path(lang):
        # Path of the training corpus for one language, e.g. "<trainpref>.en".
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        # "<prefix>.<lang>" (the language suffix is omitted when lang is None).
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        # Same as file_name but anchored under the destination directory.
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        # Where the (serialized) dictionary for `lang` is stored.
        return dest_path("dict", lang) + ".txt"

    if args.joined_dictionary:
        # One shared dictionary built over both sides of the training data.
        assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
        assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
        src_dict = build_dictionary(
            {train_path(lang) for lang in [args.source_lang, args.target_lang]},
            args.workers,
        )
        tgt_dict = src_dict
    else:
        # Separate dictionaries: reuse a given one or build from training data.
        if args.srcdict:
            src_dict = dictionary.Dictionary.load(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], args.workers)
        if target:
            if args.tgtdict:
                tgt_dict = dictionary.Dictionary.load(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary(
                    [train_path(args.target_lang)], args.workers
                )

    # Apply frequency threshold / size cap and pad to a multiple of
    # --padding-factor, then persist the dictionaries.
    src_dict.finalize(
        threshold=args.thresholdsrc,
        nwords=args.nwordssrc,
        padding_factor=args.padding_factor,
    )
    src_dict.save(dict_path(args.source_lang))
    if target:
        # With --joined-dictionary, tgt_dict IS src_dict and was finalized above;
        # finalizing twice would be wrong, so only save it here.
        if not args.joined_dictionary:
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=args.padding_factor,
            )
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        # Binarize one corpus into the indexed-dataset (.bin/.idx) format,
        # optionally splitting the file across worker processes by byte offset.
        dict = dictionary.Dictionary.load(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
        # n_seq_tok = [num sequences, num tokens]; a list so the nested
        # callback below can mutate it.
        n_seq_tok = [0, 0]
        replaced = Counter()  # counts of tokens replaced by <unk>

        def merge_result(worker_result):
            # Fold one worker's statistics into the shared totals.
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Workers 1..num_workers-1 binarize their slices to temp files;
            # this process handles slice 0 itself (below).
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        dict,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )
        # Binarize the first slice in-process while the pool works.
        merge_result(
            Tokenizer.binarize(
                input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            # Wait for all workers, then append their temp shards (in worker
            # order, preserving line order) and delete the temp files.
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                dict.unk_word,
            )
        )

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        # Dispatch on --output-format: binarize, or just copy the raw text.
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        # Process train plus every comma-separated valid/test prefix.
        # Only the training set uses multiple workers.
        if args.trainpref:
            make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        # Build a src->tgt alignment dictionary: for each aligned word pair
        # count co-occurrences, then keep the most frequent target per source.
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
        freq_map = {}  # src index -> {tgt index -> co-occurrence count}
        with open(args.alignfile, "r") as align_file:
            with open(src_file_name, "r") as src_file:
                with open(tgt_file_name, "r") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
                        # Alignment lines look like "0-1 2-2 ..." (src-tgt pairs).
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            # Skip pairs where either side maps to <unk>.
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        # Keep, for every source word, its most frequently aligned target word.
        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


def build_and_save_dictionary(
    train_path, output_path, num_workers, freq_threshold, max_words
):
    """Build a dictionary from *train_path*, finalize it, and save it.

    Args:
        train_path: path to the training corpus to count tokens from.
        output_path: directory in which to write ``dict.txt``.
        num_workers: number of parallel workers for counting.
        freq_threshold: map words appearing fewer times than this to unknown.
        max_words: number of words to retain (-1 keeps all).

    Returns:
        str: path of the saved dictionary file.
    """
    # Renamed local from ``dict`` to ``d`` — the original shadowed the builtin.
    d = build_dictionary([train_path], num_workers)
    d.finalize(threshold=freq_threshold, nwords=max_words)
    dict_path = os.path.join(output_path, "dict.txt")
    d.save(dict_path)
    return dict_path

Sergey Edunov's avatar
Sergey Edunov committed
278

Ruty Rinott's avatar
Ruty Rinott committed
279
280
281
282
283
def build_dictionary(filenames, workers):
    """Count tokens from every file in *filenames* into a single Dictionary.

    Args:
        filenames: iterable of corpus file paths.
        workers: number of parallel workers used per file.

    Returns:
        dictionary.Dictionary with accumulated token counts (not finalized).
    """
    merged = dictionary.Dictionary()
    for path in filenames:
        Tokenizer.add_file_to_dictionary(path, merged, tokenize_line, workers)
    return merged
Sergey Edunov's avatar
Sergey Edunov committed
284

Sergey Edunov's avatar
Sergey Edunov committed
285
286

def binarize(args, filename, dict, output_prefix, lang, offset, end):
    """Binarize the byte range [offset, end) of *filename* with *dict* and
    write an indexed dataset under the destination prefix.

    Returns the statistics dict produced by ``Tokenizer.binarize``
    (nseq/ntok/replaced counts).
    """
    ds = indexed_dataset.IndexedDatasetBuilder(
        dataset_dest_file(args, output_prefix, lang, "bin")
    )
    # Each binarized sentence tensor is appended straight into the builder.
    res = Tokenizer.binarize(
        filename, dict, lambda t: ds.add_item(t), offset=offset, end=end
    )
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res

Ruty Rinott's avatar
Ruty Rinott committed
298
299
300
301
302
303
304

def binarize_with_load(args, filename, dict_path, output_prefix, lang, offset, end):
    """Load a dictionary from *dict_path*, then binarize one slice of
    *filename* with it (see :func:`binarize`).

    Returns:
        str: the destination dataset prefix for the written slice.
    """
    # Renamed local from ``dict`` to ``d`` — the original shadowed the builtin.
    d = dictionary.Dictionary.load(dict_path)
    binarize(args, filename, d, output_prefix, lang, offset, end)
    return dataset_dest_prefix(args, output_prefix, lang)


Sergey Edunov's avatar
Sergey Edunov committed
305
def dataset_dest_prefix(args, output_prefix, lang):
    """Return the destination path prefix (no extension) for a dataset.

    With a language: ``<destdir>/<prefix>.<src>-<tgt>.<lang>``;
    with ``lang=None``: just ``<destdir>/<prefix>``.
    """
    if lang is None:
        return "{}/{}".format(args.destdir, output_prefix)
    return "{}/{}.{}-{}.{}".format(
        args.destdir, output_prefix, args.source_lang, args.target_lang, lang
    )
Sergey Edunov's avatar
Sergey Edunov committed
311
312
313
314


def dataset_dest_file(args, output_prefix, lang, extension):
    """Return the full destination path for a dataset file, e.g.
    ``<destdir>/<prefix>.<src>-<tgt>.<lang>.<extension>`` (the language part
    is omitted when *lang* is None).
    """
    # Prefix logic inlined from dataset_dest_prefix for a self-contained path build.
    prefix = f"{args.destdir}/{output_prefix}"
    if lang is not None:
        prefix += f".{args.source_lang}-{args.target_lang}.{lang}"
    return f"{prefix}.{extension}"


def get_offsets(input_file, num_workers):
    """Return byte offsets that split *input_file* into ``num_workers`` chunks
    (delegates to ``Tokenizer.find_offsets``)."""
    offsets = Tokenizer.find_offsets(input_file, num_workers)
    return offsets


def merge_files(files, outpath):
    """Concatenate several indexed-dataset shards into one dataset at
    ``<outpath>.bin`` / ``<outpath>.idx``, deleting each shard after merging.
    """
    builder = indexed_dataset.IndexedDatasetBuilder(f"{outpath}.bin")
    for path in files:
        builder.merge_file_(path)
        # Shard files are no longer needed once their contents are merged.
        os.remove(indexed_dataset.data_file_path(path))
        os.remove(indexed_dataset.index_file_path(path))
    builder.finalize(f"{outpath}.idx")
Sergey Edunov's avatar
Sergey Edunov committed
329
330


Ruty Rinott's avatar
Ruty Rinott committed
331
if __name__ == "__main__":
    # Parse command-line arguments and run the preprocessing pipeline.
    main(get_parser().parse_args())