Commit bbb4120b authored by Davide Caroselli, committed by Facebook Github Bot

Support custom Dictionary implementations in 'preprocess.py' (#448)

Summary:
The `preprocess.py` script has been refactored in order to:

1. Use the `options` module for command-line argument parsing. This gives `preprocess.py` the ability to load custom modules via the `--user-dir` flag (already supported by all other binaries)
2. Move dictionary loading and building code into the Task implementation. This allows custom Dictionary classes to be used during the data generation step (see the sketch below).
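
For illustration only, here is a minimal sketch of how a custom Task might plug in its own Dictionary class after this change; the module path `my_plugins`, the task name `my_translation`, and the `MyDictionary` class are hypothetical and not part of this PR:

```python
# my_plugins/__init__.py  (hypothetical user module, loaded via --user-dir)
from fairseq.data import Dictionary
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask


class MyDictionary(Dictionary):
    """Example Dictionary subclass (e.g. with extra special symbols)."""
    pass


@register_task('my_translation')
class MyTranslationTask(TranslationTask):

    @classmethod
    def load_dictionary(cls, filename):
        # preprocess.py now goes through this hook instead of calling
        # Dictionary.load() directly, so a subclass can be returned here.
        return MyDictionary.load(filename)

    # build_dictionary() could be overridden in the same way to control how
    # the dictionary is built from the raw training text.
```

With such a module on disk, preprocessing would presumably be invoked as `python preprocess.py --user-dir my_plugins --task my_translation --source-lang de --target-lang en --trainpref <prefix> --destdir data-bin`.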
Pull Request resolved: https://github.com/pytorch/fairseq/pull/448

Differential Revision: D13674819

Pulled By: myleott

fbshipit-source-id: b40648a98ed6c08284577e5ec25876e018d8c822
parent ec6f8ef9
fairseq/options.py:
@@ -17,6 +17,12 @@ from fairseq.tasks import TASK_REGISTRY
 from fairseq.utils import import_user_module
 
 
+def get_preprocessing_parser(default_task='translation'):
+    parser = get_parser('Preprocessing', default_task)
+    add_preprocess_args(parser)
+    return parser
+
+
 def get_training_parser(default_task='translation'):
     parser = get_parser('Trainer', default_task)
     add_dataset_args(parser, train=True)
@@ -142,7 +148,7 @@ def get_parser(desc, default_task='translation'):
     parser.add_argument('--fp16', action='store_true', help='use FP16')
     parser.add_argument('--memory-efficient-fp16', action='store_true',
                         help='use a memory-efficient version of FP16 training; implies --fp16')
-    parser.add_argument('--fp16-init-scale', default=2**7, type=int,
+    parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int,
                         help='default FP16 loss scale')
     parser.add_argument('--fp16-scale-window', type=int,
                         help='number of updates before increasing loss scale')
@@ -159,6 +165,50 @@ def get_parser(desc, default_task='translation'):
     return parser
 
 
+def add_preprocess_args(parser):
+    group = parser.add_argument_group('Preprocessing')
+    # fmt: off
+    group.add_argument("-s", "--source-lang", default=None, metavar="SRC",
+                       help="source language")
+    group.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
+                       help="target language")
+    group.add_argument("--trainpref", metavar="FP", default=None,
+                       help="train file prefix")
+    group.add_argument("--validpref", metavar="FP", default=None,
+                       help="comma separated, valid file prefixes")
+    group.add_argument("--testpref", metavar="FP", default=None,
+                       help="comma separated, test file prefixes")
+    group.add_argument("--destdir", metavar="DIR", default="data-bin",
+                       help="destination dir")
+    group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
+                       help="map words appearing less than threshold times to unknown")
+    group.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
+                       help="map words appearing less than threshold times to unknown")
+    group.add_argument("--tgtdict", metavar="FP",
+                       help="reuse given target dictionary")
+    group.add_argument("--srcdict", metavar="FP",
+                       help="reuse given source dictionary")
+    group.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
+                       help="number of target words to retain")
+    group.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
+                       help="number of source words to retain")
+    group.add_argument("--alignfile", metavar="ALIGN", default=None,
+                       help="an alignment file (optional)")
+    group.add_argument("--output-format", metavar="FORMAT", default="binary",
+                       choices=["binary", "raw"],
+                       help="output format (optional)")
+    group.add_argument("--joined-dictionary", action="store_true",
+                       help="Generate joined dictionary")
+    group.add_argument("--only-source", action="store_true",
+                       help="Only process the source language")
+    group.add_argument("--padding-factor", metavar="N", default=8, type=int,
+                       help="Pad dictionary size to be multiple of N")
+    group.add_argument("--workers", metavar="N", default=1, type=int,
+                       help="number of parallel workers")
+    # fmt: on
+    return parser
+
+
 def add_dataset_args(parser, train=False, gen=False):
     group = parser.add_argument_group('Dataset and data loading')
     # fmt: off
...
fairseq/tasks/__init__.py:
@@ -11,7 +11,6 @@ import os
 from .fairseq_task import FairseqTask
 
 TASK_REGISTRY = {}
 TASK_CLASS_NAMES = set()
@@ -73,3 +72,7 @@ for file in os.listdir(os.path.dirname(__file__)):
             group_args = parser.add_argument_group('Additional command-line arguments')
             TASK_REGISTRY[task_name].add_args(group_args)
             globals()[task_name + '_parser'] = parser
+
+
+def get_task(name):
+    return TASK_REGISTRY[name]
fairseq/tasks/fairseq_task.py:
@@ -5,9 +5,12 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from fairseq.data import data_utils, FairseqDataset, iterators
 import torch
 
+from fairseq import tokenizer
+from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
+from fairseq.tokenizer import Tokenizer
+
 
 class FairseqTask(object):
     """
@@ -24,6 +27,35 @@ class FairseqTask(object):
         self.args = args
         self.datasets = {}
 
+    @classmethod
+    def load_dictionary(cls, filename):
+        """Load the dictionary from the filename
+
+        Args:
+            filename (str): the filename
+        """
+        return Dictionary.load(filename)
+
+    @classmethod
+    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
+        """Build the dictionary
+
+        Args:
+            filenames (list): list of filenames
+            workers (int): number of concurrent workers
+            threshold (int): defines the minimum word count
+            nwords (int): defines the total number of words in the final dictionary,
+                including special symbols
+            padding_factor (int): can be used to pad the dictionary size to be a
+                multiple of 8, which is important on some hardware (e.g., Nvidia
+                Tensor Cores).
+        """
+        d = Dictionary()
+        for filename in filenames:
+            Tokenizer.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
+        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
+        return d
+
     @classmethod
     def setup_task(cls, args, **kwargs):
         """Setup the task (e.g., load dictionaries).
@@ -59,9 +91,9 @@ class FairseqTask(object):
         return self.datasets[split]
 
     def get_batch_iterator(
         self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
         ignore_invalid_inputs=False, required_batch_size_multiple=1,
         seed=1, num_shards=1, shard_id=0, num_workers=0,
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
...
fairseq/tasks/translation.py:
@@ -109,8 +109,8 @@ class TranslationTask(FairseqTask):
             raise Exception('Could not infer language pair, please provide it explicitly')
 
         # load dictionaries
-        src_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
-        tgt_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
+        src_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
+        tgt_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
         assert src_dict.pad() == tgt_dict.pad()
         assert src_dict.eos() == tgt_dict.eos()
         assert src_dict.unk() == tgt_dict.unk()
...
preprocess.py:
@@ -9,64 +9,19 @@
 Data pre-processing: build vocabularies and binarize training data.
 """
 
-import argparse
 from collections import Counter
 from itertools import zip_longest
 import os
 import shutil
 
-from fairseq.data import indexed_dataset, dictionary
-from fairseq.tokenizer import Tokenizer, tokenize_line
+from fairseq import options, tasks
+from fairseq.data import indexed_dataset
+from fairseq.tokenizer import Tokenizer
 from multiprocessing import Pool
 
 from fairseq.utils import import_user_module
 
 
-def get_parser():
-    parser = argparse.ArgumentParser()
-    # fmt: off
-    parser.add_argument("-s", "--source-lang", default=None, metavar="SRC",
-                        help="source language")
-    parser.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
-                        help="target language")
-    parser.add_argument("--trainpref", metavar="FP", default=None,
-                        help="train file prefix")
-    parser.add_argument("--validpref", metavar="FP", default=None,
-                        help="comma separated, valid file prefixes")
-    parser.add_argument("--testpref", metavar="FP", default=None,
-                        help="comma separated, test file prefixes")
-    parser.add_argument("--destdir", metavar="DIR", default="data-bin",
-                        help="destination dir")
-    parser.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
-                        help="map words appearing less than threshold times to unknown")
-    parser.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
-                        help="map words appearing less than threshold times to unknown")
-    parser.add_argument("--tgtdict", metavar="FP",
-                        help="reuse given target dictionary")
-    parser.add_argument("--srcdict", metavar="FP",
-                        help="reuse given source dictionary")
-    parser.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
-                        help="number of target words to retain")
-    parser.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
-                        help="number of source words to retain")
-    parser.add_argument("--alignfile", metavar="ALIGN", default=None,
-                        help="an alignment file (optional)")
-    parser.add_argument("--output-format", metavar="FORMAT", default="binary",
-                        choices=["binary", "raw"],
-                        help="output format (optional)")
-    parser.add_argument("--joined-dictionary", action="store_true",
-                        help="Generate joined dictionary")
-    parser.add_argument("--only-source", action="store_true",
-                        help="Only process the source language")
-    parser.add_argument("--padding-factor", metavar="N", default=8, type=int,
-                        help="Pad dictionary size to be multiple of N")
-    parser.add_argument("--workers", metavar="N", default=1, type=int,
-                        help="number of parallel workers")
-    # fmt: on
-    return parser
 
 
 def main(args):
     import_user_module(args)
@@ -74,6 +29,8 @@ def main(args):
     os.makedirs(args.destdir, exist_ok=True)
     target = not args.only_source
 
+    task = tasks.get_task(args.task)
+
     def train_path(lang):
         return "{}{}".format(args.trainpref, ("." + lang) if lang else "")
@@ -89,50 +46,57 @@ def main(args):
     def dict_path(lang):
         return dest_path("dict", lang) + ".txt"
 
-    if args.joined_dictionary:
-        assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
-        assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
-        src_dict = build_dictionary(
-            {train_path(lang) for lang in [args.source_lang, args.target_lang]},
-            args.workers,
+    def build_dictionary(filenames, src=False, tgt=False):
+        assert src ^ tgt
+        return task.build_dictionary(
+            filenames,
+            workers=args.workers,
+            threshold=args.thresholdsrc if src else args.thresholdtgt,
+            nwords=args.nwordssrc if src else args.nwordstgt,
+            padding_factor=args.padding_factor,
         )
+
+    if args.joined_dictionary:
+        assert (
+            not args.srcdict or not args.tgtdict
+        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"
+        if args.srcdict:
+            src_dict = task.load_dictionary(args.srcdict)
+        elif args.tgtdict:
+            src_dict = task.load_dictionary(args.tgtdict)
+        else:
+            assert (
+                args.trainpref
+            ), "--trainpref must be set if --srcdict is not specified"
+            src_dict = build_dictionary({train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True)
         tgt_dict = src_dict
     else:
         if args.srcdict:
-            src_dict = dictionary.Dictionary.load(args.srcdict)
+            src_dict = task.load_dictionary(args.srcdict)
         else:
             assert (
                 args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
-            src_dict = build_dictionary([train_path(args.source_lang)], args.workers)
+            src_dict = build_dictionary([train_path(args.source_lang)], src=True)
         if target:
             if args.tgtdict:
-                tgt_dict = dictionary.Dictionary.load(args.tgtdict)
+                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                 assert (
                     args.trainpref
                 ), "--trainpref must be set if --tgtdict is not specified"
-                tgt_dict = build_dictionary(
-                    [train_path(args.target_lang)], args.workers
-                )
+                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
+        else:
+            tgt_dict = None
 
-    src_dict.finalize(
-        threshold=args.thresholdsrc,
-        nwords=args.nwordssrc,
-        padding_factor=args.padding_factor,
-    )
     src_dict.save(dict_path(args.source_lang))
-    if target:
-        if not args.joined_dictionary:
-            tgt_dict.finalize(
-                threshold=args.thresholdtgt,
-                nwords=args.nwordstgt,
-                padding_factor=args.padding_factor,
-            )
+    if target and tgt_dict is not None:
         tgt_dict.save(dict_path(args.target_lang))
 
     def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
-        dict = dictionary.Dictionary.load(dict_path(lang))
+        dict = task.load_dictionary(dict_path(lang))
         print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
         n_seq_tok = [0, 0]
         replaced = Counter()
@@ -229,8 +193,6 @@ def main(args):
         assert args.trainpref, "--trainpref must be set if --alignfile is specified"
         src_file_name = train_path(args.source_lang)
         tgt_file_name = train_path(args.target_lang)
-        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
-        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
         freq_map = {}
         with open(args.alignfile, "r", encoding='utf-8') as align_file:
             with open(src_file_name, "r", encoding='utf-8') as src_file:
@@ -260,37 +222,16 @@ def main(args):
             align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
 
         with open(
             os.path.join(
                 args.destdir,
                 "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
             ),
             "w", encoding='utf-8'
         ) as f:
             for k, v in align_dict.items():
                 print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
 
 
-def build_and_save_dictionary(
-    train_path, output_path, num_workers, freq_threshold, max_words, dict_cls=dictionary.Dictionary,
-):
-    dict = build_dictionary([train_path], num_workers, dict_cls)
-    dict.finalize(threshold=freq_threshold, nwords=max_words)
-    dict_path = os.path.join(output_path, "dict.txt")
-    dict.save(dict_path)
-    return dict_path
-
-
-def build_dictionary(
-    filenames,
-    workers,
-    dict_cls=dictionary.Dictionary,
-):
-    d = dict_cls()
-    for filename in filenames:
-        Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, workers)
-    return d
 def binarize(args, filename, dict, output_prefix, lang, offset, end):
     ds = indexed_dataset.IndexedDatasetBuilder(
         dataset_dest_file(args, output_prefix, lang, "bin")
@@ -304,21 +245,6 @@ def binarize(args, filename, dict, output_prefix, lang, offset, end):
     return res
 
 
-def binarize_with_load(
-    args,
-    filename,
-    dict_path,
-    output_prefix,
-    lang,
-    offset,
-    end,
-    dict_cls=dictionary.Dictionary,
-):
-    dict = dict_cls.load(dict_path)
-    binarize(args, filename, dict, output_prefix, lang, offset, end)
-    return dataset_dest_prefix(args, output_prefix, lang)
-
-
 def dataset_dest_prefix(args, output_prefix, lang):
     base = "{}/{}".format(args.destdir, output_prefix)
     lang_part = (
@@ -346,6 +272,6 @@ def merge_files(files, outpath):
 
 if __name__ == "__main__":
-    parser = get_parser()
+    parser = options.get_preprocessing_parser()
     args = parser.parse_args()
     main(args)
tests/utils.py:
@@ -223,7 +223,7 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
 
 def preprocess_translation_data(data_dir, extra_flags=None):
-    preprocess_parser = preprocess.get_parser()
+    preprocess_parser = options.get_preprocessing_parser()
     preprocess_args = preprocess_parser.parse_args(
         [
             '--source-lang', 'in',
@@ -291,7 +291,7 @@ def generate_main(data_dir, extra_flags=None):
 
 def preprocess_lm_data(data_dir):
-    preprocess_parser = preprocess.get_parser()
+    preprocess_parser = options.get_preprocessing_parser()
     preprocess_args = preprocess_parser.parse_args([
         '--only-source',
         '--trainpref', os.path.join(data_dir, 'train.out'),
...