Commit 5c15ce77 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix a mistake in previous change

PiperOrigin-RevId: 281409019
parent 252e6384
......@@ -59,12 +59,10 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
"""Returns input dataset from input file string."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
input_files = []
for input_pattern in input_file_pattern.split(','):
input_files.extend(tf.io.gfile.glob(input_pattern))
input_patterns = input_file_pattern.split(',')
batch_size = ctx.get_per_replica_batch_size(global_batch_size)
train_dataset = input_pipeline.create_pretrain_dataset(
input_files,
input_patterns,
seq_length,
max_predictions_per_seq,
batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment