preprocess.py 10.4 KB
Newer Older
Louis Martin's avatar
Louis Martin committed
1
#!/usr/bin/env python3
Sergey Edunov's avatar
Sergey Edunov committed
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8
9
10
"""
Data pre-processing: build vocabularies and binarize training data.
"""
Sergey Edunov's avatar
Sergey Edunov committed
11

Sergey Edunov's avatar
Sergey Edunov committed
12
from collections import Counter
Sergey Edunov's avatar
Sergey Edunov committed
13
14
from itertools import zip_longest

15
16
from fairseq import options, tasks
from fairseq.data import indexed_dataset
17
18
from fairseq.binarizer import Binarizer
from fairseq.utils import import_user_module
Myle Ott's avatar
Myle Ott committed
19
from multiprocessing import Pool
Sergey Edunov's avatar
Sergey Edunov committed
20

21
22
import os
import shutil
23

Sergey Edunov's avatar
Sergey Edunov committed
24

Myle Ott's avatar
Myle Ott committed
25
def main(args):
    """Build vocabularies and binarize train/valid/test data.

    Expects the argparse namespace produced by
    ``options.get_preprocessing_parser()``. Writes dictionaries, binarized
    (or raw-copied) datasets, and optionally an alignment dictionary into
    ``args.destdir``.
    """
    import_user_module(args)
    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    # Unless --only-source is given, a target-side dictionary/dataset is built too.
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        # Path of the training file for `lang` (no suffix when lang is falsy).
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        # "<prefix>.<lang>" or just "<prefix>" when lang is None.
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        # Destination path inside args.destdir.
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        # e.g. "<destdir>/dict.<lang>.txt"
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        # Exactly one of src/tgt must be set; it selects which threshold/nwords
        # command-line options apply.
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    # Refuse to clobber pre-existing dictionaries unless they were explicitly
    # supplied via --srcdict/--tgtdict.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        # A single dictionary shared by source and target.
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            # Set comprehension dedupes the paths when source == target lang.
            src_dict = build_dictionary(
                {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
            )
        tgt_dict = src_dict
    else:
        # Separate dictionaries per side.
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        """Binarize one split, fanning out to worker processes when asked."""
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        # Mutable list so merge_result (a closure) can accumulate counts.
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            # Fold one worker's stats into the shared totals.
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        # Byte offsets splitting the input into num_workers chunks.
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Workers 1..num_workers-1 binarize their chunk into temp files;
            # chunk 0 is handled by this process below.
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1]
                    ),
                    callback=merge_result
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )
        # Binarize the first chunk in-process while the pool works.
        merge_result(
            Binarizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t),
                offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            # Append each worker's temp output to the main dataset, then
            # delete the temp .bin/.idx files.
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            )
        )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        # Dispatch on the requested output format.
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang, vocab):
        # Process train plus any comma-separated valid/test prefixes.
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        # Build a source->target alignment dictionary: for every aligned
        # (src token, tgt token) pair, count co-occurrences and keep the most
        # frequent target token per source token.
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    # NOTE(review): zip_longest yields None for the shorter
                    # files, which would raise below — assumes the three files
                    # have the same number of lines; confirm with callers.
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        # Each alignment entry is "srcpos-tgtpos".
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            # Skip pairs involving unknown tokens.
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        # Keep only the most frequent target token for each source token.
        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
                os.path.join(
                    args.destdir,
                    "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
                ),
                "w", encoding='utf-8'
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


236
def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
    """Binarize the byte range [offset, end) of ``filename`` with ``vocab``.

    Writes the encoded items into the ``.bin``/``.idx`` pair named by
    ``output_prefix``/``lang`` and returns the statistics dict produced by
    ``Binarizer.binarize`` (run as a worker by ``make_binary_dataset``).
    """
    builder = indexed_dataset.IndexedDatasetBuilder(
        dataset_dest_file(args, output_prefix, lang, "bin")
    )
    stats = Binarizer.binarize(
        filename,
        vocab,
        lambda tensor: builder.add_item(tensor),
        append_eos=append_eos,
        offset=offset,
        end=end,
    )
    builder.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return stats

Ruty Rinott's avatar
Ruty Rinott committed
249

Sergey Edunov's avatar
Sergey Edunov committed
250
def dataset_dest_prefix(args, output_prefix, lang):
    """Return the destination path prefix for a dataset file.

    Yields ``<destdir>/<output_prefix>.<src>-<tgt>.<lang>`` when ``lang`` is
    given, else just ``<destdir>/<output_prefix>``.

    Fix: join ``destdir`` with ``os.path.join`` instead of a hard-coded
    ``"/"`` separator, consistent with ``dest_path`` in ``main`` and portable
    across platforms.
    """
    base = os.path.join(args.destdir, output_prefix)
    if lang is None:
        return base
    return "{}.{}-{}.{}".format(base, args.source_lang, args.target_lang, lang)
Sergey Edunov's avatar
Sergey Edunov committed
256
257
258
259


def dataset_dest_file(args, output_prefix, lang, extension):
    """Return the full destination path: the dataset prefix plus ``.extension``."""
    prefix = dataset_dest_prefix(args, output_prefix, lang)
    return "{}.{}".format(prefix, extension)
Ruty Rinott's avatar
Ruty Rinott committed
261
262
263


def get_offsets(input_file, num_workers):
    """Compute byte offsets that split ``input_file`` into ``num_workers`` chunks."""
    offsets = Binarizer.find_offsets(input_file, num_workers)
    return offsets
Ruty Rinott's avatar
Ruty Rinott committed
265
266
267
268
269
270
271
272
273


def merge_files(files, outpath):
    """Concatenate several indexed datasets into one at ``outpath``.

    Each input is appended to ``<outpath>.bin`` and its temporary
    ``.bin``/``.idx`` files are deleted; the merged index is written to
    ``<outpath>.idx``.
    """
    builder = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath))
    for path in files:
        builder.merge_file_(path)
        os.remove(indexed_dataset.data_file_path(path))
        os.remove(indexed_dataset.index_file_path(path))
    builder.finalize("{}.idx".format(outpath))
Sergey Edunov's avatar
Sergey Edunov committed
274
275


Myle Ott's avatar
Myle Ott committed
276
def cli_main():
    """Command-line entry point: parse preprocessing options and run ``main``."""
    main(options.get_preprocessing_parser().parse_args())


if __name__ == "__main__":
    cli_main()