preprocess.py 13 KB
Newer Older
Louis Martin's avatar
Louis Martin committed
1
#!/usr/bin/env python3
Sergey Edunov's avatar
Sergey Edunov committed
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8
9
10
"""
Data pre-processing: build vocabularies and binarize training data.
"""
Sergey Edunov's avatar
Sergey Edunov committed
11
12

import argparse
Sergey Edunov's avatar
Sergey Edunov committed
13
from collections import Counter
Sergey Edunov's avatar
Sergey Edunov committed
14
from itertools import zip_longest
15
16
import os
import shutil
Sergey Edunov's avatar
Sergey Edunov committed
17

Sergey Edunov's avatar
Sergey Edunov committed
18

alexeib's avatar
alexeib committed
19
from fairseq.data import indexed_dataset, dictionary
Myle Ott's avatar
Myle Ott committed
20
from fairseq.tokenizer import Tokenizer, tokenize_line
Myle Ott's avatar
Myle Ott committed
21
from multiprocessing import Pool
Sergey Edunov's avatar
Sergey Edunov committed
22

23
24
from fairseq.utils import import_user_module

Sergey Edunov's avatar
Sergey Edunov committed
25

Myle Ott's avatar
Myle Ott committed
26
def get_parser():
    """Build the command-line argument parser for data pre-processing.

    All options are optional; see each ``help`` string for its meaning.
    """
    p = argparse.ArgumentParser()
    # fmt: off
    p.add_argument("-s", "--source-lang", default=None, metavar="SRC",
                   help="source language")
    p.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
                   help="target language")
    p.add_argument("--trainpref", metavar="FP", default=None,
                   help="train file prefix")
    p.add_argument("--validpref", metavar="FP", default=None,
                   help="comma separated, valid file prefixes")
    p.add_argument("--testpref", metavar="FP", default=None,
                   help="comma separated, test file prefixes")
    p.add_argument("--destdir", metavar="DIR", default="data-bin",
                   help="destination dir")
    p.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
                   help="map words appearing less than threshold times to unknown")
    p.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
                   help="map words appearing less than threshold times to unknown")
    p.add_argument("--tgtdict", metavar="FP",
                   help="reuse given target dictionary")
    p.add_argument("--srcdict", metavar="FP",
                   help="reuse given source dictionary")
    p.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
                   help="number of target words to retain")
    p.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
                   help="number of source words to retain")
    p.add_argument("--alignfile", metavar="ALIGN", default=None,
                   help="an alignment file (optional)")
    p.add_argument("--output-format", metavar="FORMAT", default="binary",
                   choices=["binary", "raw"],
                   help="output format (optional)")
    p.add_argument("--joined-dictionary", action="store_true",
                   help="Generate joined dictionary")
    p.add_argument("--only-source", action="store_true",
                   help="Only process the source language")
    p.add_argument("--padding-factor", metavar="N", default=8, type=int,
                   help="Pad dictionary size to be multiple of N")
    p.add_argument("--workers", metavar="N", default=1, type=int,
                   help="number of parallel workers")
    # fmt: on
    return p
Sergey Edunov's avatar
Sergey Edunov committed
68

Myle Ott's avatar
Myle Ott committed
69

Myle Ott's avatar
Myle Ott committed
70
def main(args):
    """Run data pre-processing: build vocabularies and binarize data.

    Expects the namespace produced by ``get_parser()``.  Writes dictionaries,
    binarized (or raw-copied) train/valid/test datasets and, optionally, a
    word-alignment file into ``args.destdir``.
    """
    import_user_module(args)

    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    # Unless --only-source is given, the target language is processed too.
    target = not args.only_source

    def train_path(lang):
        # Training file for `lang`, e.g. "<trainpref>.en" ("" suffix if no lang).
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        # "<prefix>.<lang>", or just "<prefix>" when no language is given.
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        # Destination path inside args.destdir.
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        # Dictionary file for `lang`: "<destdir>/dict.<lang>.txt".
        return dest_path("dict", lang) + ".txt"

    # --- Build (or load) the source/target dictionaries -------------------
    if args.joined_dictionary:
        # One dictionary shared by source and target, built over both training
        # files (the set comprehension de-duplicates identical paths).
        assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
        assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
        src_dict = build_dictionary(
            {train_path(lang) for lang in [args.source_lang, args.target_lang]},
            args.workers,
        )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = dictionary.Dictionary.load(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], args.workers)
        if target:
            if args.tgtdict:
                tgt_dict = dictionary.Dictionary.load(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary(
                    [train_path(args.target_lang)], args.workers
                )

    # Apply frequency threshold / vocabulary size limit / padding, then save.
    src_dict.finalize(
        threshold=args.thresholdsrc,
        nwords=args.nwordssrc,
        padding_factor=args.padding_factor,
    )
    src_dict.save(dict_path(args.source_lang))
    if target:
        if not args.joined_dictionary:
            # With a joined dictionary, tgt_dict *is* src_dict and was already
            # finalized above; it must not be finalized a second time.
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=args.padding_factor,
            )
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        # Binarize one text file into an indexed dataset, optionally sharding
        # the work across `num_workers` processes.
        # NOTE(review): the local name `dict` shadows the builtin.
        dict = dictionary.Dictionary.load(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
        n_seq_tok = [0, 0]  # running totals: [num sequences, num tokens]
        replaced = Counter()  # tokens replaced by the unknown word, per type

        def merge_result(worker_result):
            # Fold one worker's statistics into the shared counters.
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        # Byte offsets splitting input_file into num_workers chunks.
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Workers 1..num_workers-1 binarize their chunks into temporary
            # "<output_prefix><worker_id>" datasets; chunk 0 is done inline
            # in this process (below), hence the pool has num_workers - 1.
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        dict,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )
        # Binarize the first chunk here while the pool processes the rest.
        merge_result(
            Tokenizer.binarize(
                input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            # Append each worker's temporary shard to the main dataset in
            # worker order, then remove the shard files.
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        # NOTE(review): this division raises ZeroDivisionError when the input
        # file is empty (n_seq_tok[1] == 0) — confirm inputs are non-empty.
        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                dict.unk_word,
            )
        )

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        # Dispatch on --output-format: binarize, or copy the raw text through.
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        # Process the train prefix plus every comma-separated valid/test prefix.
        if args.trainpref:
            make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                # Extra valid sets are named "valid1", "valid2", ...
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    # --- Optional: build a word-alignment dictionary ----------------------
    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
        # freq_map[src_index][tgt_index] -> count of that aligned pair.
        freq_map = {}
        with open(args.alignfile, "r") as align_file:
            with open(src_file_name, "r") as src_file:
                with open(tgt_file_name, "r") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
                        # Each alignment line is whitespace-separated
                        # "srcpos-tgtpos" pairs, e.g. "0-0 1-2".
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            # Count only pairs where both sides are in-vocab.
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        # For each source word keep its most frequently aligned target word.
        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


def build_and_save_dictionary(
    train_path, output_path, num_workers, freq_threshold, max_words
):
    """Build, finalize and save a dictionary from one training file.

    Args:
        train_path: path of the training text file to count tokens from.
        output_path: directory in which "dict.txt" is written.
        num_workers: number of parallel workers used while counting.
        freq_threshold: map words seen fewer than this many times to unknown.
        max_words: maximum number of words to retain (-1 keeps all).

    Returns:
        The path of the saved dictionary file.
    """
    # Renamed local from `dict` — it shadowed the builtin.
    vocab = build_dictionary([train_path], num_workers)
    vocab.finalize(threshold=freq_threshold, nwords=max_words)
    out_file = os.path.join(output_path, "dict.txt")
    vocab.save(out_file)
    return out_file

Sergey Edunov's avatar
Sergey Edunov committed
282

Ruty Rinott's avatar
Ruty Rinott committed
283
284
285
286
287
def build_dictionary(filenames, workers):
    """Return a Dictionary populated with token counts from *filenames*.

    Each file is tokenized with ``tokenize_line`` using *workers* parallel
    workers; counts from all files accumulate into a single dictionary.
    """
    merged = dictionary.Dictionary()
    for path in filenames:
        Tokenizer.add_file_to_dictionary(path, merged, tokenize_line, workers)
    return merged
Sergey Edunov's avatar
Sergey Edunov committed
288

Sergey Edunov's avatar
Sergey Edunov committed
289
290

def binarize(args, filename, dict, output_prefix, lang, offset, end):
    """Binarize the byte range [offset, end) of *filename* into an
    indexed dataset under the destination prefix for *output_prefix*/*lang*.

    Returns the statistics dict produced by ``Tokenizer.binarize``.
    (The parameter name ``dict`` shadows the builtin but is kept for
    interface compatibility.)
    """
    builder = indexed_dataset.IndexedDatasetBuilder(
        dataset_dest_file(args, output_prefix, lang, "bin")
    )
    stats = Tokenizer.binarize(
        filename, dict, lambda t: builder.add_item(t), offset=offset, end=end
    )
    builder.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return stats

Ruty Rinott's avatar
Ruty Rinott committed
302
303
304
305
306
307
308

def binarize_with_load(args, filename, dict_path, output_prefix, lang, offset, end):
    """Load a dictionary from *dict_path*, then binarize *filename* with it.

    Convenience wrapper around :func:`binarize` for worker processes that
    cannot share an in-memory dictionary.

    Returns:
        The destination dataset prefix (see :func:`dataset_dest_prefix`).
    """
    # Renamed local from `dict` — it shadowed the builtin.
    vocab = dictionary.Dictionary.load(dict_path)
    binarize(args, filename, vocab, output_prefix, lang, offset, end)
    return dataset_dest_prefix(args, output_prefix, lang)


Sergey Edunov's avatar
Sergey Edunov committed
309
def dataset_dest_prefix(args, output_prefix, lang):
    """Return the output path prefix for a dataset split.

    "<destdir>/<output_prefix>.<src>-<tgt>.<lang>" when *lang* is given,
    otherwise just "<destdir>/<output_prefix>".
    """
    if lang is None:
        suffix = ""
    else:
        suffix = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang)
    return "{}/{}{}".format(args.destdir, output_prefix, suffix)
Sergey Edunov's avatar
Sergey Edunov committed
315
316
317
318


def dataset_dest_file(args, output_prefix, lang, extension):
    """Return the dataset destination path with *extension* appended,
    e.g. "<destdir>/train.en-de.en.bin"."""
    return "{}.{}".format(
        dataset_dest_prefix(args, output_prefix, lang), extension
    )


def get_offsets(input_file, num_workers):
    """Return byte offsets that split *input_file* into *num_workers* chunks."""
    offsets = Tokenizer.find_offsets(input_file, num_workers)
    return offsets


def merge_files(files, outpath):
    """Concatenate several indexed-dataset shards into one dataset at
    *outpath*, deleting each shard's .bin/.idx files after merging."""
    builder = indexed_dataset.IndexedDatasetBuilder(outpath + ".bin")
    for shard in files:
        builder.merge_file_(shard)
        os.remove(indexed_dataset.data_file_path(shard))
        os.remove(indexed_dataset.index_file_path(shard))
    builder.finalize(outpath + ".idx")
Sergey Edunov's avatar
Sergey Edunov committed
333
334


Ruty Rinott's avatar
Ruty Rinott committed
335
if __name__ == "__main__":
    # Script entry point: parse CLI options and run preprocessing.
    main(get_parser().parse_args())