Commit 745d5fbd authored by Myle Ott

Pad dictionary to be a multiple of 8 in preprocessing

parent 4cd2bb70
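
Note on the change below: the padding simply rounds the vocabulary size up to the next multiple of --padding-factor by appending unused filler symbols. As a quick standalone illustration (not part of the commit; the num_fillers helper and the sizes are made up for this note), the number of fillers needed for a vocabulary of size n is -n % factor:

    # Hypothetical helper illustrating the rounding; not fairseq code.
    def num_fillers(vocab_size, padding_factor=8):
        # Fillers needed to reach the next multiple of padding_factor.
        return -vocab_size % padding_factor

    assert num_fillers(33712) == 0  # already a multiple of 8
    assert num_fillers(10013) == 3  # 10013 + 3 = 10016 = 8 * 1252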
@@ -24,13 +24,6 @@ def tokenize_line(line):
 class Tokenizer:
-    @staticmethod
-    def build_dictionary(filename, tokenize=tokenize_line):
-        dict = dictionary.Dictionary()
-        Tokenizer.add_file_to_dictionary(filename, dict, tokenize)
-        dict.finalize()
-        return dict
     @staticmethod
     def add_file_to_dictionary(filename, dict, tokenize):
         with open(filename, 'r') as f:
@@ -38,6 +38,7 @@ def get_parser():
                         help='output format (optional)')
     parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
     parser.add_argument('--only-source', action='store_true', help='Only process the source language')
+    parser.add_argument('--padding-factor', metavar='N', default=8, type=int, help='Pad dictionary size to be multiple of N')
     return parser
@@ -46,30 +47,47 @@ def main(args):
     os.makedirs(args.destdir, exist_ok=True)
     target = not args.only_source
 
+    def pad_dictionary(d):
+        """Pad dictionary to be a multiple of args.padding_factor.
+
+        Keeping the dictionary size a multiple of 8 improves performance on some
+        architectures, e.g., Nvidia Tensor Cores.
+        """
+        if args.padding_factor > 1:
+            i = 0
+            while len(d) % args.padding_factor != 0:
+                d.add_symbol('madeupword{:04d}'.format(i))
+                i += 1
+            assert len(d) % args.padding_factor == 0
+
+    def build_dictionary(filenames):
+        d = dictionary.Dictionary()
+        for filename in filenames:
+            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
+        pad_dictionary(d)
+        d.finalize()
+        return d
+
     if args.joined_dictionary:
         assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
         assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
-        src_dict = dictionary.Dictionary()
-        for lang in [args.source_lang, args.target_lang]:
-            Tokenizer.add_file_to_dictionary(
-                filename='{}.{}'.format(args.trainpref, lang),
-                dict=src_dict,
-                tokenize=tokenize_line,
-            )
-        src_dict.finalize()
+        src_dict = build_dictionary([
+            '{}.{}'.format(args.trainpref, lang)
+            for lang in [args.source_lang, args.target_lang]
+        ])
         tgt_dict = src_dict
     else:
         if args.srcdict:
             src_dict = dictionary.Dictionary.load(args.srcdict)
         else:
             assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
-            src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
+            src_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.source_lang)])
         if target:
             if args.tgtdict:
                 tgt_dict = dictionary.Dictionary.load(args.tgtdict)
             else:
                 assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
-                tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
+                tgt_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.target_lang)])
 
     src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
                   threshold=args.thresholdsrc, nwords=args.nwordssrc)
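
For anyone reading the diff without the surrounding file: a minimal standalone sketch of how the new pad_dictionary helper behaves, using FakeDictionary as a stand-in for fairseq's dictionary.Dictionary (the stand-in class and the example symbols are assumptions for illustration, not fairseq code):

    # Stand-in reduced to the two members the padding loop relies on.
    class FakeDictionary:
        def __init__(self, symbols):
            self.symbols = list(symbols)

        def add_symbol(self, word):
            self.symbols.append(word)

        def __len__(self):
            return len(self.symbols)

    def pad_dictionary(d, padding_factor=8):
        # Same loop as in the commit: append numbered filler words until
        # the dictionary size is a multiple of padding_factor.
        if padding_factor > 1:
            i = 0
            while len(d) % padding_factor != 0:
                d.add_symbol('madeupword{:04d}'.format(i))
                i += 1
            assert len(d) % padding_factor == 0

    d = FakeDictionary(['<pad>', '</s>', '<unk>'] + ['w{}'.format(i) for i in range(10)])  # 13 entries
    pad_dictionary(d)
    print(len(d), d.symbols[-3:])  # 16 ['madeupword0000', 'madeupword0001', 'madeupword0002']

Because the filler words never occur in the training data, they do not change tokenization; they only round the vocabulary size (and hence the embedding and output-projection dimensions) up to a multiple of 8, which is what the docstring's Tensor Core note refers to.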