Commit 84a43b13 authored by Mike Chrzanowski, committed by Jared Casper

bug fixes in partitioned data preprocessor

parent 8ce8256f
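
The diff touches the tokenizer setup (`build_tokenizer`) and the partitioned preprocessing pipeline (`Partition`, `get_args`, `main`).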
```diff
@@ -15,8 +15,10 @@ def build_tokenizer(args):
     print('> building {} tokenizer ...'.format(args.tokenizer_type),
           flush=True)
 
+    if args.tokenizer_type != 'SentencePieceTokenizer':
+        assert args.vocab_file is not None
+
     # Select and instantiate the tokenizer.
-    assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
         tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                             lower_case=True,
```
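
This first hunk relaxes the vocab-file check in `build_tokenizer`: only tokenizer types other than `SentencePieceTokenizer` are required to supply `--vocab-file`, presumably because a SentencePiece tokenizer loads its own model file instead. A minimal sketch of the guard; the helper name `check_tokenizer_args` is made up for illustration and is not part of the diff:

```python
# Hypothetical sketch of the guard introduced above: only tokenizer types that
# read a vocab file must supply one; 'SentencePieceTokenizer' is exempt.
def check_tokenizer_args(tokenizer_type, vocab_file):
    if tokenizer_type != 'SentencePieceTokenizer':
        assert vocab_file is not None, \
            '{} requires a vocab file'.format(tokenizer_type)


check_tokenizer_args('BertWordPieceLowerCase', 'bert-vocab.txt')  # passes
check_tokenizer_args('SentencePieceTokenizer', None)              # no longer asserts
```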
```diff
@@ -174,6 +174,7 @@ class Partition(object):
            self.print_processing_stats(i, proc_start, total_bytes_processed)
 
         fin.close()
+        builders[key].finalize(output_idx_files[key])
 
 
 def get_args():
```
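
The second hunk has `Partition` finalize its dataset builder once the input file has been fully consumed; without the `finalize` call, the index file for that partition is never written. The toy builder below is purely illustrative (the class name and on-disk layout are assumptions, not Megatron's actual indexed-dataset implementation), but it shows the add/finalize contract the fix relies on:

```python
import struct

# Toy stand-in for an indexed-dataset builder (illustrative only): documents are
# streamed into a .bin file, but the .idx metadata only exists once finalize() runs.
class ToyBuilder:
    def __init__(self, bin_path):
        self._f = open(bin_path, 'wb')
        self.sizes = []

    def add_doc(self, token_ids):
        self._f.write(struct.pack('{}i'.format(len(token_ids)), *token_ids))
        self.sizes.append(len(token_ids))

    def finalize(self, idx_path):
        self._f.close()
        with open(idx_path, 'w') as f:
            f.write('\n'.join(str(s) for s in self.sizes))


builder = ToyBuilder('partition_0_text.bin')
builder.add_doc([101, 7592, 102])
builder.finalize('partition_0_text.idx')  # skipping this leaves the .idx file missing
```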
```diff
@@ -219,9 +220,8 @@ def get_args():
     args = parser.parse_args()
     args.keep_empty = False
 
-    if (args.tokenizer_type.lower().startswith('bert')
-        if not args.split_sentences:
-            print("Are you sure you don't want to split sentences?")
+    if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
+        print("Are you sure you don't want to split sentences?")
 
     # some default/dummy values for the tokenizer
     args.rank = 1
```
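
The third hunk collapses a malformed nested `if` in `get_args` into a single condition, so the sentence-splitting hint is printed only when a BERT-style tokenizer is used without `--split-sentences`. A small check of the corrected logic; the wrapper function is hypothetical:

```python
# Sketch of the corrected warning condition; maybe_warn is a made-up wrapper.
def maybe_warn(tokenizer_type, split_sentences):
    if tokenizer_type.lower().startswith('bert') and not split_sentences:
        print("Are you sure you don't want to split sentences?")


maybe_warn('BertWordPieceLowerCase', split_sentences=False)  # prints the hint
maybe_warn('BertWordPieceLowerCase', split_sentences=True)   # silent
maybe_warn('GPT2BPETokenizer', split_sentences=False)        # silent
```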
```diff
@@ -265,7 +265,11 @@ def main():
     if args.partitions == 1:
         file_name, extension = os.path.splitext(args.input)
         sentence_split_file = file_name + "_ss" + extension
-        in_ss_out_names.append((args.input, sentence_split_file, args.output_prefix))
+        file_names = {
+            'partition': args.input,
+            'sentence_split': sentence_split_file,
+            'output_prefix': args.output_prefix}
+        in_ss_out_names.append(file_names)
     else:
         in_file_names = glob.glob(args.input)
```
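
The fourth hunk changes the single-partition entry of `in_ss_out_names` from a positional tuple to a dict, presumably to match the shape produced by the multi-partition branch so downstream steps can read the paths by key. An illustrative use of the dict form; the concrete file names here are made up:

```python
# The single-partition entry now carries named fields instead of tuple positions.
file_names = {
    'partition': 'corpus.json',
    'sentence_split': 'corpus_ss.json',
    'output_prefix': 'my-dataset'}
in_ss_out_names = [file_names]

for name in in_ss_out_names:
    # Downstream stages can look paths up by key rather than by index.
    print(name['partition'], '->', name['sentence_split'], '->', name['output_prefix'])
```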
```diff
@@ -358,7 +362,7 @@ def main():
             full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix,
                                                              key, level)
             builders[key].merge_file_(full_partition_output_prefix)
-        builder[key].finalize(output_idx_files[key])
+        builders[key].finalize(output_idx_files[key])
 
 
 if __name__ == '__main__':
```
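
The last hunk fixes a typo in `main`: the merge loop appends each partition's output into `builders[key]`, but the final call referenced an undefined name `builder`, which would raise a `NameError` at runtime; the fix finalizes `builders[key]` so the merged dataset's index file actually gets written.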