#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Data pre-processing: build vocabularies and binarize training data.
"""

from collections import Counter
from itertools import zip_longest
import os
import shutil

from fairseq import options, tasks
from fairseq.data import indexed_dataset
from fairseq.tokenizer import Tokenizer
from multiprocessing import Pool

from fairseq.utils import import_user_module


def main(args):
    import_user_module(args)

    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

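    # Build or load the vocabularies. With --joined-dictionary a single
    # dictionary is estimated over the source and target training files
    # together and shared between both sides; otherwise each side either
    # loads --srcdict/--tgtdict or estimates a dictionary from its own
    # training file.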
    if args.joined_dictionary:
        assert (
            not args.srcdict or not args.tgtdict
        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {train_path(lang) for lang in [args.source_lang, args.target_lang]},
                src=True,
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
        else:
            tgt_dict = None

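    # Write the vocabularies to <destdir>/dict.<lang>.txt so that training
    # and generation can reload exactly the same token-to-index mappings.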
    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

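    # Binarize one split for one language. The input text is chunked by byte
    # offsets so that num_workers processes can encode disjoint regions in
    # parallel; the per-worker shards are merged into a single dataset below.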
    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        dict = task.load_dictionary(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
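        # offsets[i] is the byte offset at which chunk i starts. Chunk 0 is
        # binarized in this process; chunks 1..num_workers-1 are dispatched
        # to the pool and written to temporary per-worker shard files.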
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        dict,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

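        # Binarize the first chunk directly into the final output file while
        # the workers (if any) handle the remaining chunks.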
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )
        merge_result(
            Tokenizer.binarize(
                input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

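        # Write the index (.idx) describing the binarized (.bin) data.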
        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                dict.unk_word,
            )
        )

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

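    # Process every split: one train set plus any number of comma-separated
    # validation/test sets, named valid, valid1, ... and test, test1, ...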
    def make_all(lang):
        if args.trainpref:
            make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print("| Wrote preprocessed data to {}".format(args.destdir))

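    # Optionally build an alignment dictionary: for each source token, count
    # how often it is aligned to each target token in the training data and
    # keep the most frequent target token (e.g. for unknown-word replacement
    # at generation time).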
    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w", encoding='utf-8'
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


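# Worker entry point: binarize the byte range [offset, end) of filename into
# a temporary shard (<output_prefix>.bin/.idx) that the parent merges back in.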
def binarize(args, filename, dict, output_prefix, lang, offset, end):
    ds = indexed_dataset.IndexedDatasetBuilder(
        dataset_dest_file(args, output_prefix, lang, "bin")
    )

    def consumer(tensor):
        ds.add_item(tensor)

    res = Tokenizer.binarize(filename, dict, consumer, offset=offset, end=end)
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res


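# Destination path helpers, e.g. ("train", "de") for a de-en pair maps to
# <destdir>/train.de-en.de, plus a ".bin" or ".idx" extension.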
def dataset_dest_prefix(args, output_prefix, lang):
    base = "{}/{}".format(args.destdir, output_prefix)
    lang_part = (
        ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) if lang is not None else ""
    )
    return "{}{}".format(base, lang_part)


def dataset_dest_file(args, output_prefix, lang, extension):
    base = dataset_dest_prefix(args, output_prefix, lang)
    return "{}.{}".format(base, extension)


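# Standalone counterparts to the chunking and shard-merging logic used in
# make_binary_dataset above.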
def get_offsets(input_file, num_workers):
    return Tokenizer.find_offsets(input_file, num_workers)


def merge_files(files, outpath):
    ds = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath))
    for file in files:
        ds.merge_file_(file)
        os.remove(indexed_dataset.data_file_path(file))
        os.remove(indexed_dataset.index_file_path(file))
    ds.finalize("{}.idx".format(outpath))


if __name__ == "__main__":
    parser = options.get_preprocessing_parser()
    args = parser.parse_args()
    main(args)