"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "371d4e5df30dc55702ce812006e0624dbba9bbb0"
Commit 5c15ce77 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix a mistake in previous change

PiperOrigin-RevId: 281409019
parent 252e6384
...@@ -59,12 +59,10 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length, ...@@ -59,12 +59,10 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
"""Returns input dataset from input file string.""" """Returns input dataset from input file string."""
def _dataset_fn(ctx=None): def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining.""" """Returns tf.data.Dataset for distributed BERT pretraining."""
input_files = [] input_patterns = input_file_pattern.split(',')
for input_pattern in input_file_pattern.split(','):
input_files.extend(tf.io.gfile.glob(input_pattern))
batch_size = ctx.get_per_replica_batch_size(global_batch_size) batch_size = ctx.get_per_replica_batch_size(global_batch_size)
train_dataset = input_pipeline.create_pretrain_dataset( train_dataset = input_pipeline.create_pretrain_dataset(
input_files, input_patterns,
seq_length, seq_length,
max_predictions_per_seq, max_predictions_per_seq,
batch_size, batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment