Commit 745d5fbd authored by Myle Ott
Browse files

Pad dictionary to be a multiple of 8 in preprocessing

parent 4cd2bb70
...@@ -24,13 +24,6 @@ def tokenize_line(line): ...@@ -24,13 +24,6 @@ def tokenize_line(line):
class Tokenizer: class Tokenizer:
@staticmethod
def build_dictionary(filename, tokenize=tokenize_line):
    """Create a finalized dictionary from a single text file.

    Args:
        filename: path of the file whose contents are added to the dictionary.
        tokenize: callable handed to ``add_file_to_dictionary`` to split
            lines; defaults to ``tokenize_line``.

    Returns:
        The ``dictionary.Dictionary`` instance after ``finalize()`` was called.
    """
    # Named ``vocab`` rather than ``dict`` so the builtin is not shadowed.
    vocab = dictionary.Dictionary()
    Tokenizer.add_file_to_dictionary(filename, vocab, tokenize)
    vocab.finalize()
    return vocab
@staticmethod @staticmethod
def add_file_to_dictionary(filename, dict, tokenize): def add_file_to_dictionary(filename, dict, tokenize):
with open(filename, 'r') as f: with open(filename, 'r') as f:
......
...@@ -38,6 +38,7 @@ def get_parser(): ...@@ -38,6 +38,7 @@ def get_parser():
help='output format (optional)') help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary') parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language') parser.add_argument('--only-source', action='store_true', help='Only process the source language')
parser.add_argument('--padding-factor', metavar='N', default=8, help='Pad dictionary size to be multiple of N')
return parser return parser
...@@ -46,30 +47,47 @@ def main(args): ...@@ -46,30 +47,47 @@ def main(args):
os.makedirs(args.destdir, exist_ok=True) os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source target = not args.only_source
def pad_dictionary(d, padding_factor=None):
    """Pad dictionary to be a multiple of the padding factor.

    Keeping the dictionary size a multiple of 8 improves performance on some
    architectures, e.g., Nvidia Tensor Cores.

    Args:
        d: dictionary-like object; must support ``len()`` and
            ``add_symbol(str)``.
        padding_factor: multiple to pad the size to. When omitted, the value
            of ``args.padding_factor`` from the enclosing scope is used.

    Raises:
        ValueError: if the padding factor is a string that is not an integer.
    """
    if padding_factor is None:
        padding_factor = args.padding_factor
    # --padding-factor is declared without type=int, so a value supplied on
    # the command line arrives as a string. Comparing a str with 1 raises
    # TypeError on Python 3, so normalise to int before any arithmetic.
    factor = int(padding_factor)
    if factor <= 1:
        # Factors of 1 or less mean "no padding requested".
        return
    i = 0
    # The loop only exits once the size is a multiple of ``factor``.
    while len(d) % factor != 0:
        d.add_symbol('madeupword{:04d}'.format(i))
        i += 1
def build_dictionary(filenames):
    """Build one dictionary from several text files.

    Every file in ``filenames`` is added to a fresh ``dictionary.Dictionary``
    using ``tokenize_line``; the result is then padded via ``pad_dictionary``
    and finalized.

    Args:
        filenames: iterable of file paths to read.

    Returns:
        The padded dictionary after ``finalize()`` was called.
    """
    vocab = dictionary.Dictionary()
    for path in filenames:
        Tokenizer.add_file_to_dictionary(path, vocab, tokenize_line)
    # Padding happens before finalize(), matching the original call order.
    pad_dictionary(vocab)
    vocab.finalize()
    return vocab
if args.joined_dictionary: if args.joined_dictionary:
assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary' assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary' assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
src_dict = dictionary.Dictionary() src_dict = build_dictionary([
for lang in [args.source_lang, args.target_lang]: '{}.{}'.format(args.trainpref, lang)
Tokenizer.add_file_to_dictionary( for lang in [args.source_lang, args.target_lang]
filename='{}.{}'.format(args.trainpref, lang), ])
dict=src_dict,
tokenize=tokenize_line,
)
src_dict.finalize()
tgt_dict = src_dict tgt_dict = src_dict
else: else:
if args.srcdict: if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict) src_dict = dictionary.Dictionary.load(args.srcdict)
else: else:
assert args.trainpref, "--trainpref must be set if --srcdict is not specified" assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang)) src_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.source_lang)])
if target: if target:
if args.tgtdict: if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict) tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else: else:
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified" assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang)) tgt_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.target_lang)])
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)), src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
threshold=args.thresholdsrc, nwords=args.nwordssrc) threshold=args.thresholdsrc, nwords=args.nwordssrc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment