Commit 880e7cd4 authored by Ruty Rinott, committed by Facebook Github Bot

pipeline for LM training

Summary:
Step 2 of the pipeline for LM training.
Assumes tokenized text data as input; splits it into train/validation/test and runs binarization
(step a_ii in https://fb.quip.com/kazzAxvZHBj9)
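
As a hedged illustration (not part of the diff below), this step could be invoked for monolingual LM data roughly as follows, assuming the script is importable as `preprocess` and using placeholder paths. LM corpora have no target side, so `--only-source` is passed and no target dictionary is built.

from preprocess import get_parser, main  # module name is an assumption

parser = get_parser()
args = parser.parse_args([
    "--only-source",                        # monolingual data: no target language
    "--trainpref", "/path/to/tokenized.train",
    "--validpref", "/path/to/tokenized.valid",
    "--testpref", "/path/to/tokenized.test",
    "--destdir", "data-bin/lm",
    "--workers", "4",                       # parallel binarization workers
    "--output-format", "binary",
])
main(args)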

Reviewed By: borguz

Differential Revision: D10454705

fbshipit-source-id: 74e8679041f5507c4e404c1b719547c2ae9ed983
parent 189fcabf
@@ -21,31 +21,91 @@ from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool, Manager, Process
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
parser.add_argument('--trainpref', metavar='FP', default=None, help='train file prefix')
parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid file prefixes')
parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test file prefixes')
parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--thresholdsrc', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--tgtdict', metavar='FP', help='reuse given target dictionary')
parser.add_argument('--srcdict', metavar='FP', help='reuse given source dictionary')
parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language')
parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
help='Pad dictionary size to be multiple of N')
parser.add_argument('--workers', metavar='N', default=1, type=int, help='number of parallel workers')
parser.add_argument(
"-s", "--source-lang", default=None, metavar="SRC", help="source language"
)
parser.add_argument(
"-t", "--target-lang", default=None, metavar="TARGET", help="target language"
)
parser.add_argument(
"--trainpref", metavar="FP", default=None, help="train file prefix"
)
parser.add_argument(
"--validpref",
metavar="FP",
default=None,
help="comma separated, valid file prefixes",
)
parser.add_argument(
"--testpref",
metavar="FP",
default=None,
help="comma separated, test file prefixes",
)
parser.add_argument(
"--destdir", metavar="DIR", default="data-bin", help="destination dir"
)
parser.add_argument(
"--thresholdtgt",
metavar="N",
default=0,
type=int,
help="map words appearing less than threshold times to unknown",
)
parser.add_argument(
"--thresholdsrc",
metavar="N",
default=0,
type=int,
help="map words appearing less than threshold times to unknown",
)
parser.add_argument("--tgtdict", metavar="FP", help="reuse given target dictionary")
parser.add_argument("--srcdict", metavar="FP", help="reuse given source dictionary")
parser.add_argument(
"--nwordstgt",
metavar="N",
default=-1,
type=int,
help="number of target words to retain",
)
parser.add_argument(
"--nwordssrc",
metavar="N",
default=-1,
type=int,
help="number of source words to retain",
)
parser.add_argument(
"--alignfile",
metavar="ALIGN",
default=None,
help="an alignment file (optional)",
)
parser.add_argument(
"--output-format",
metavar="FORMAT",
default="binary",
choices=["binary", "raw"],
help="output format (optional)",
)
parser.add_argument(
"--joined-dictionary", action="store_true", help="Generate joined dictionary"
)
parser.add_argument(
"--only-source", action="store_true", help="Only process the source language"
)
parser.add_argument(
"--padding-factor",
metavar="N",
default=8,
type=int,
help="Pad dictionary size to be multiple of N",
)
parser.add_argument(
"--workers", metavar="N", default=1, type=int, help="number of parallel workers"
)
return parser
@@ -54,47 +114,47 @@ def main(args):
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
def build_dictionary(filenames):
d = dictionary.Dictionary()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, args.workers)
return d
def train_path(lang):
return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')
return "{}{}".format(args.trainpref, ("." + lang) if lang else "")
def file_name(prefix, lang):
fname = prefix
if lang is not None:
fname += f'.{lang}'
fname += f".{lang}"
return fname
def dest_path(prefix, lang):
return os.path.join(args.destdir, file_name(prefix, lang))
def dict_path(lang):
return dest_path('dict', lang) + '.txt'
return dest_path("dict", lang) + ".txt"
if args.joined_dictionary:
assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
src_dict = build_dictionary(set([
train_path(lang)
for lang in [args.source_lang, args.target_lang]
]))
assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
src_dict = build_dictionary(
{train_path(lang) for lang in [args.source_lang, args.target_lang]},
args.workers,
)
tgt_dict = src_dict
else:
if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict)
else:
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)])
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)], args.workers)
if target:
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else:
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary([train_path(args.target_lang)])
assert (
args.trainpref
), "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary(
[train_path(args.target_lang)], args.workers
)
src_dict.finalize(
threshold=args.thresholdsrc,
@@ -113,30 +173,47 @@ def main(args):
def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
dict = dictionary.Dictionary.load(dict_path(lang))
print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
n_seq_tok = [0, 0]
replaced = Counter()
def merge_result(worker_result):
replaced.update(worker_result['replaced'])
n_seq_tok[0] += worker_result['nseq']
n_seq_tok[1] += worker_result['ntok']
replaced.update(worker_result["replaced"])
n_seq_tok[0] += worker_result["nseq"]
n_seq_tok[1] += worker_result["ntok"]
input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
input_file = "{}{}".format(
input_prefix, ("." + lang) if lang is not None else ""
)
offsets = Tokenizer.find_offsets(input_file, num_workers)
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers-1)
pool = Pool(processes=num_workers - 1)
for worker_id in range(1, num_workers):
prefix = "{}{}".format(output_prefix, worker_id)
pool.apply_async(binarize, (args, input_file, dict, prefix, lang,
offsets[worker_id],
offsets[worker_id + 1]), callback=merge_result)
pool.apply_async(
binarize,
(
args,
input_file,
dict,
prefix,
lang,
offsets[worker_id],
offsets[worker_id + 1],
),
callback=merge_result,
)
pool.close()
ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
merge_result(Tokenizer.binarize(input_file, dict, lambda t: ds.add_item(t),
offset=0, end=offsets[1]))
ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin")
)
merge_result(
Tokenizer.binarize(
input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
)
)
if num_workers > 1:
pool.join()
for worker_id in range(1, num_workers):
@@ -146,44 +223,47 @@ def main(args):
os.remove(indexed_dataset.data_file_path(temp_file_path))
os.remove(indexed_dataset.index_file_path(temp_file_path))
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))
print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
lang, input_file, n_seq_tok[0], n_seq_tok[1],
100 * sum(replaced.values()) / n_seq_tok[1], dict.unk_word))
print(
"| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
lang,
input_file,
n_seq_tok[0],
n_seq_tok[1],
100 * sum(replaced.values()) / n_seq_tok[1],
dict.unk_word,
)
)
def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
if args.output_format == 'binary':
if args.output_format == "binary":
make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
elif args.output_format == 'raw':
elif args.output_format == "raw":
# Copy original text file to destination folder
output_text_file = dest_path(
output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang),
output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
lang,
)
shutil.copyfile(file_name(input_prefix, lang), output_text_file)
def make_all(lang):
if args.trainpref:
make_dataset(args.trainpref, 'train', lang, num_workers=args.workers)
make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
for k, validpref in enumerate(args.validpref.split(",")):
outprefix = "valid{}".format(k) if k > 0 else "valid"
make_dataset(validpref, outprefix, lang)
if args.testpref:
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
for k, testpref in enumerate(args.testpref.split(",")):
outprefix = "test{}".format(k) if k > 0 else "test"
make_dataset(testpref, outprefix, lang)
make_all(args.source_lang)
if target:
make_all(args.target_lang)
print('| Wrote preprocessed data to {}'.format(args.destdir))
print("| Wrote preprocessed data to {}".format(args.destdir))
if args.alignfile:
assert args.trainpref, "--trainpref must be set if --alignfile is specified"
@@ -192,13 +272,13 @@ def main(args):
src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
freq_map = {}
with open(args.alignfile, 'r') as align_file:
with open(src_file_name, 'r') as src_file:
with open(tgt_file_name, 'r') as tgt_file:
with open(args.alignfile, "r") as align_file:
with open(src_file_name, "r") as src_file:
with open(tgt_file_name, "r") as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
ai = list(map(lambda x: tuple(x.split('-')), a.split()))
ai = list(map(lambda x: tuple(x.split("-")), a.split()))
for sai, tai in ai:
srcidx = si[int(sai)]
tgtidx = ti[int(tai)]
@@ -219,35 +299,80 @@ def main(args):
for srcidx in freq_map.keys():
align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
with open(os.path.join(args.destdir, 'alignment.{}-{}.txt'.format(
args.source_lang, args.target_lang)), 'w') as f:
with open(
os.path.join(
args.destdir,
"alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
),
"w",
) as f:
for k, v in align_dict.items():
print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def build_and_save_dictionary(
train_path, output_path, num_workers, freq_threshold, max_words
):
dict = build_dictionary([train_path], num_workers)
dict.finalize(threshold=freq_threshold, nwords=max_words)
dict_path = os.path.join(output_path, "dict.txt")
dict.save(dict_path)
return dict_path
def build_dictionary(filenames, workers):
d = dictionary.Dictionary()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, workers)
return d
def binarize(args, filename, dict, output_prefix, lang, offset, end):
ds = indexed_dataset.IndexedDatasetBuilder(
dataset_dest_file(args, output_prefix, lang, "bin")
)
ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
def consumer(tensor):
ds.add_item(tensor)
res = Tokenizer.binarize(filename, dict, consumer, offset=offset, end=end)
ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
return res
def binarize_with_load(args, filename, dict_path, output_prefix, lang, offset, end):
dict = dictionary.Dictionary.load(dict_path)
binarize(args, filename, dict, output_prefix, lang, offset, end)
return dataset_dest_prefix(args, output_prefix, lang)
def dataset_dest_prefix(args, output_prefix, lang):
base = f'{args.destdir}/{output_prefix}'
lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
return f'{base}{lang_part}'
base = f"{args.destdir}/{output_prefix}"
lang_part = (
f".{args.source_lang}-{args.target_lang}.{lang}" if lang is not None else ""
)
return f"{base}{lang_part}"
def dataset_dest_file(args, output_prefix, lang, extension):
base = dataset_dest_prefix(args, output_prefix, lang)
return f'{base}.{extension}'
return f"{base}.{extension}"
def get_offsets(input_file, num_workers):
return Tokenizer.find_offsets(input_file, num_workers)
def merge_files(files, outpath):
ds = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath))
for file in files:
ds.merge_file_(file)
os.remove(indexed_dataset.data_file_path(file))
os.remove(indexed_dataset.index_file_path(file))
ds.finalize("{}.idx".format(outpath))
if __name__ == '__main__':
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
main(args)
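
A hedged sketch, not part of this commit: one way a separate LM-pipeline driver might compose the helpers added above (get_offsets, binarize_with_load, merge_files, dataset_dest_prefix) to binarize a large file in parallel. The module name `preprocess`, the driver function, the `args` namespace, and all paths are assumptions for illustration; the orchestration mirrors make_binary_dataset.

from multiprocessing import Pool

# Assumes the script above is importable as `preprocess`.
from preprocess import (
    binarize_with_load,
    dataset_dest_prefix,
    get_offsets,
    merge_files,
)


def binarize_in_parallel(args, input_file, dict_path, output_prefix, lang, num_workers):
    # Split the input file into byte ranges, one per worker.
    offsets = get_offsets(input_file, num_workers)
    with Pool(processes=num_workers) as pool:
        # Each worker loads the dictionary itself and writes its own shard;
        # binarize_with_load returns that shard's dataset prefix.
        chunk_prefixes = pool.starmap(
            binarize_with_load,
            [
                (
                    args,
                    input_file,
                    dict_path,
                    "{}{}".format(output_prefix, worker_id),
                    lang,
                    offsets[worker_id],
                    offsets[worker_id + 1],
                )
                for worker_id in range(num_workers)
            ],
        )
    # Stitch the per-worker .bin/.idx shards into a single dataset.
    merge_files(chunk_prefixes, dataset_dest_prefix(args, output_prefix, lang))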