Unverified Commit c92a7e16, authored by Hongkun Yu and committed by GitHub

Merged commit includes the following changes: (#7347)

260862396 by A. Unique TensorFlower <gardener@tensorflow.org>:

    Fix BERT pretraining input pipeline to shuffle and shard dataset properly for multi-worker training.

--

PiperOrigin-RevId: 260862396
parent 1b089751
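The gist of the fix: sharding and shuffling now happen at the file level, so each worker in a multi-worker job reads a disjoint subset of the input files instead of every worker iterating the full dataset in the same order. Below is a minimal sketch of the pattern with hypothetical file names and worker topology; the committed code derives these from the input flags and a tf.distribute.InputContext, and lets list_files do the initial shuffling.

    import tensorflow as tf

    # Hypothetical values, for illustration only.
    file_paths = ['pretrain-00.tfrecord', 'pretrain-01.tfrecord',
                  'pretrain-02.tfrecord', 'pretrain-03.tfrecord']
    num_workers, worker_id = 2, 0

    files = tf.data.Dataset.list_files(file_paths, shuffle=False)
    # Each worker keeps every num_workers-th file starting at its own id,
    # so the workers read disjoint file subsets.
    files = files.shard(num_workers, worker_id)
    # Repeat at the file level and reshuffle with a buffer covering all files.
    files = files.repeat().shuffle(len(file_paths))
    # Read several files at once so records from different files are mixed.
    dataset = files.interleave(
        tf.data.TFRecordDataset, cycle_length=tf.data.experimental.AUTOTUNE)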
@@ -21,29 +21,30 @@ from __future__ import print_function

 import tensorflow as tf


-def file_based_input_fn_builder(input_file, name_to_features):
-  """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-  def _decode_record(record, name_to_features):
-    """Decodes a record to a TensorFlow example."""
-    example = tf.io.parse_single_example(record, name_to_features)
-
-    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-    # So cast all int64 to int32.
-    for name in list(example.keys()):
-      t = example[name]
-      if t.dtype == tf.int64:
-        t = tf.cast(t, tf.int32)
-      example[name] = t
-
-    return example
+def decode_record(record, name_to_features):
+  """Decodes a record to a TensorFlow example."""
+  example = tf.io.parse_single_example(record, name_to_features)
+
+  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+  # So cast all int64 to int32.
+  for name in list(example.keys()):
+    t = example[name]
+    if t.dtype == tf.int64:
+      t = tf.cast(t, tf.int32)
+    example[name] = t
+
+  return example
+
+
+def file_based_input_fn_builder(input_file, name_to_features):
+  """Creates an `input_fn` closure to be passed for BERT custom training."""

   def input_fn():
     """Returns dataset for training/evaluation."""
     # For training, we want a lot of parallel reading and shuffling.
     # For eval, we want no shuffling and parallel reading doesn't matter.
     d = tf.data.TFRecordDataset(input_file)
-    d = d.map(lambda record: _decode_record(record, name_to_features))
+    d = d.map(lambda record: decode_record(record, name_to_features))

     # When `input_file` is a path to a single file or a list
     # containing a single path, disable auto sharding so that
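Hoisting _decode_record out of file_based_input_fn_builder (now the module-level decode_record) lets the new pretraining pipeline below reuse it. A quick sketch of calling it directly, with an assumed two-feature spec and file name:

    # Assumed minimal feature spec, for illustration; the real pretraining
    # spec is built inside create_pretrain_dataset below.
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([128], tf.int64),
        'input_mask': tf.io.FixedLenFeature([128], tf.int64),
    }
    dataset = tf.data.TFRecordDataset('pretrain-00.tfrecord')
    # Parses each serialized tf.Example and casts int64 features to the
    # int32 dtype the TPU supports.
    dataset = dataset.map(lambda record: decode_record(record, name_to_features))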
@@ -57,11 +58,12 @@ def file_based_input_fn_builder(input_file, name_to_features):

   return input_fn


-def create_pretrain_dataset(file_path,
+def create_pretrain_dataset(file_paths,
                             seq_length,
                             max_predictions_per_seq,
                             batch_size,
-                            is_training=True):
+                            is_training=True,
+                            input_pipeline_context=None):
   """Creates input dataset from (tf)records files for pretraining."""
   name_to_features = {
       'input_ids':
@@ -80,8 +82,24 @@ def create_pretrain_dataset(file_path,
           tf.io.FixedLenFeature([1], tf.int64),
   }

-  input_fn = file_based_input_fn_builder(file_path, name_to_features)
-  dataset = input_fn()
+  dataset = tf.data.Dataset.list_files(file_paths, shuffle=is_training)
+
+  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+                            input_pipeline_context.input_pipeline_id)
+
+  dataset = dataset.repeat()
+
+  # We set shuffle buffer to exactly match total number of
+  # training files to ensure that training data is well shuffled.
+  dataset = dataset.shuffle(len(file_paths))
+
+  # In parallel, create a TFRecord dataset for each training file.
+  dataset = dataset.interleave(
+      tf.data.TFRecordDataset, cycle_length=tf.data.experimental.AUTOTUNE)
+
+  decode_fn = lambda record: decode_record(record, name_to_features)
+  dataset = dataset.map(
+      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

   def _select_data_from_record(record):
     """Filter out features to use for pretraining."""
@@ -99,11 +117,12 @@ def create_pretrain_dataset(file_path,

     return (x, y)

-  dataset = dataset.map(_select_data_from_record)
+  dataset = dataset.map(
+      _select_data_from_record,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)

   if is_training:
     dataset = dataset.shuffle(100)
-    dataset = dataset.repeat()

   dataset = dataset.batch(batch_size, drop_remainder=True)
   dataset = dataset.prefetch(1024)
...
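With the extended signature, every worker is handed the same file list, and the optional input_pipeline_context determines which shard of the files that worker reads. A hedged usage sketch follows; the context is built by hand here with hypothetical values, whereas in training the distribution strategy supplies it (see the next file):

    ctx = tf.distribute.InputContext(num_input_pipelines=2, input_pipeline_id=0)
    dataset = create_pretrain_dataset(
        ['pretrain-00.tfrecord', 'pretrain-01.tfrecord'],  # hypothetical files
        seq_length=128,
        max_predictions_per_seq=20,
        batch_size=32,
        is_training=True,
        input_pipeline_context=ctx)

Because the file-level dataset is already repeat()ed, the old record-level repeat() under is_training is dropped; the shuffle buffer sized to the number of files keeps the file order well mixed across passes.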
@@ -77,14 +77,18 @@ def get_pretrain_input_data(input_file_pattern, seq_length,
     batch_size = int(batch_size / strategy.num_replicas_in_sync)

   def _dataset_fn(ctx=None):
-    del ctx
+    """Returns tf.data.Dataset for distributed BERT pretraining."""
     input_files = []
     for input_pattern in input_file_pattern.split(','):
       input_files.extend(tf.io.gfile.glob(input_pattern))
     train_dataset = input_pipeline.create_pretrain_dataset(
-        input_files, seq_length, max_predictions_per_seq, batch_size)
+        input_files,
+        seq_length,
+        max_predictions_per_seq,
+        batch_size,
+        is_training=True,
+        input_pipeline_context=ctx)
     return train_dataset

   return _dataset_fn if use_dataset_fn else _dataset_fn()
...
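The ctx argument is no longer discarded: when _dataset_fn is handed to a distribution strategy, the strategy calls it once per input pipeline with a populated tf.distribute.InputContext. A minimal wiring sketch, assuming the multi-worker strategy and the experimental distribute-from-function API of TF releases contemporary with this commit:

    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    # Invokes _dataset_fn(ctx) on each worker, filling in that worker's
    # num_input_pipelines and input_pipeline_id.
    dist_dataset = strategy.experimental_distribute_datasets_from_function(
        _dataset_fn)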