Commit 880e7cd4 authored by Ruty Rinott, committed by Facebook GitHub Bot

pipeline for LM training

Summary:
Step 2 of the pipeline for LM training.
Assumes tokenized text data as input, splits it into train/validation/test sets, and runs binarization
(step a_ii in https://fb.quip.com/kazzAxvZHBj9).
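
A minimal usage sketch for the source-only LM case, assuming the preprocess.py shown below is importable as a module named `preprocess`; the file paths and worker count are illustrative, not part of this commit:

# Hypothetical invocation of the preprocessing step for LM data (source side only).
# The module name `preprocess` and all paths below are assumptions for illustration.
import preprocess

parser = preprocess.get_parser()
args = parser.parse_args([
    "--only-source",                    # LM data has no target language
    "--trainpref", "corpus.train.tok",  # tokenized text, one sentence per line
    "--validpref", "corpus.valid.tok",
    "--testpref", "corpus.test.tok",
    "--destdir", "data-bin/lm",
    "--workers", "8",
])
preprocess.main(args)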

Reviewed By: borguz

Differential Revision: D10454705

fbshipit-source-id: 74e8679041f5507c4e404c1b719547c2ae9ed983
parent 189fcabf
@@ -21,31 +21,91 @@ from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool, Manager, Process


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s", "--source-lang", default=None, metavar="SRC", help="source language"
    )
    parser.add_argument(
        "-t", "--target-lang", default=None, metavar="TARGET", help="target language"
    )
    parser.add_argument(
        "--trainpref", metavar="FP", default=None, help="train file prefix"
    )
    parser.add_argument(
        "--validpref",
        metavar="FP",
        default=None,
        help="comma separated, valid file prefixes",
    )
    parser.add_argument(
        "--testpref",
        metavar="FP",
        default=None,
        help="comma separated, test file prefixes",
    )
    parser.add_argument(
        "--destdir", metavar="DIR", default="data-bin", help="destination dir"
    )
    parser.add_argument(
        "--thresholdtgt",
        metavar="N",
        default=0,
        type=int,
        help="map words appearing less than threshold times to unknown",
    )
    parser.add_argument(
        "--thresholdsrc",
        metavar="N",
        default=0,
        type=int,
        help="map words appearing less than threshold times to unknown",
    )
    parser.add_argument("--tgtdict", metavar="FP", help="reuse given target dictionary")
    parser.add_argument("--srcdict", metavar="FP", help="reuse given source dictionary")
    parser.add_argument(
        "--nwordstgt",
        metavar="N",
        default=-1,
        type=int,
        help="number of target words to retain",
    )
    parser.add_argument(
        "--nwordssrc",
        metavar="N",
        default=-1,
        type=int,
        help="number of source words to retain",
    )
    parser.add_argument(
        "--alignfile",
        metavar="ALIGN",
        default=None,
        help="an alignment file (optional)",
    )
    parser.add_argument(
        "--output-format",
        metavar="FORMAT",
        default="binary",
        choices=["binary", "raw"],
        help="output format (optional)",
    )
    parser.add_argument(
        "--joined-dictionary", action="store_true", help="Generate joined dictionary"
    )
    parser.add_argument(
        "--only-source", action="store_true", help="Only process the source language"
    )
    parser.add_argument(
        "--padding-factor",
        metavar="N",
        default=8,
        type=int,
        help="Pad dictionary size to be multiple of N",
    )
    parser.add_argument(
        "--workers", metavar="N", default=1, type=int, help="number of parallel workers"
    )
    return parser
@@ -54,47 +114,47 @@ def main(args):
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += f".{lang}"
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    if args.joined_dictionary:
        assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
        assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
        src_dict = build_dictionary(
            {train_path(lang) for lang in [args.source_lang, args.target_lang]},
            args.workers,
        )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = dictionary.Dictionary.load(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], args.workers)
        if target:
            if args.tgtdict:
                tgt_dict = dictionary.Dictionary.load(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary(
                    [train_path(args.target_lang)], args.workers
                )

    src_dict.finalize(
        threshold=args.thresholdsrc,
@@ -113,30 +173,47 @@ def main(args):
    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        dict = dictionary.Dictionary.load(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))

        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        dict,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )
        merge_result(
            Tokenizer.binarize(
                input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
@@ -146,44 +223,47 @@ def main(args):
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                dict.unk_word,
            )
        )

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        if args.trainpref:
            make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
@@ -192,13 +272,13 @@ def main(args):
        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
        freq_map = {}
        with open(args.alignfile, "r") as align_file:
            with open(src_file_name, "r") as src_file:
                with open(tgt_file_name, "r") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
@@ -219,35 +299,80 @@ def main(args):
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


def build_and_save_dictionary(
    train_path, output_path, num_workers, freq_threshold, max_words
):
    dict = build_dictionary([train_path], num_workers)
    dict.finalize(threshold=freq_threshold, nwords=max_words)
    dict_path = os.path.join(output_path, "dict.txt")
    dict.save(dict_path)
    return dict_path


def build_dictionary(filenames, workers):
    d = dictionary.Dictionary()
    for filename in filenames:
        Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, workers)
    return d


def binarize(args, filename, dict, output_prefix, lang, offset, end):
    ds = indexed_dataset.IndexedDatasetBuilder(
        dataset_dest_file(args, output_prefix, lang, "bin")
    )

    def consumer(tensor):
        ds.add_item(tensor)

    res = Tokenizer.binarize(filename, dict, consumer, offset=offset, end=end)
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res


def binarize_with_load(args, filename, dict_path, output_prefix, lang, offset, end):
    dict = dictionary.Dictionary.load(dict_path)
    binarize(args, filename, dict, output_prefix, lang, offset, end)
    return dataset_dest_prefix(args, output_prefix, lang)


def dataset_dest_prefix(args, output_prefix, lang):
    base = f"{args.destdir}/{output_prefix}"
    lang_part = (
        f".{args.source_lang}-{args.target_lang}.{lang}" if lang is not None else ""
    )
    return f"{base}{lang_part}"


def dataset_dest_file(args, output_prefix, lang, extension):
    base = dataset_dest_prefix(args, output_prefix, lang)
    return f"{base}.{extension}"


def get_offsets(input_file, num_workers):
    return Tokenizer.find_offsets(input_file, num_workers)


def merge_files(files, outpath):
    ds = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath))
    for file in files:
        ds.merge_file_(file)
        os.remove(indexed_dataset.data_file_path(file))
        os.remove(indexed_dataset.index_file_path(file))
    ds.finalize("{}.idx".format(outpath))


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)
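
Taken together, the new module-level helpers (build_and_save_dictionary, get_offsets, binarize_with_load, merge_files) are the pieces an external LM pipeline driver would compose to shard binarization across workers. The sketch below is a hedged illustration under assumptions: the driver function `binarize_corpus`, the pool setup, and the threshold values are not part of this commit.

# Illustrative driver composing the new helpers; `binarize_corpus` and the
# threshold values are assumptions, not code from this commit.
from multiprocessing import Pool

import preprocess  # assumes preprocess.py is importable as a module


def binarize_corpus(args, input_file, lang, output_prefix, num_workers):
    # Build a dictionary from the input text and save it under args.destdir.
    dict_path = preprocess.build_and_save_dictionary(
        input_file, args.destdir, num_workers, freq_threshold=0, max_words=-1
    )
    # Split the file into byte ranges, binarize each range in its own worker,
    # then merge the per-worker shards into a single indexed dataset.
    offsets = preprocess.get_offsets(input_file, num_workers)
    with Pool(processes=num_workers) as pool:
        results = [
            pool.apply_async(
                preprocess.binarize_with_load,
                (args, input_file, dict_path, "{}{}".format(output_prefix, i),
                 lang, offsets[i], offsets[i + 1]),
            )
            for i in range(num_workers)
        ]
        shard_prefixes = [r.get() for r in results]
    preprocess.merge_files(
        shard_prefixes, preprocess.dataset_dest_prefix(args, output_prefix, lang)
    )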