Commit 5a55f69d authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 310104070
parent 44c3e33f
......@@ -41,7 +41,9 @@ def single_file_dataset(input_file, name_to_features):
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: decode_record(record, name_to_features))
d = d.map(
lambda record: decode_record(record, name_to_features),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
# When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that
......@@ -107,9 +109,13 @@ def create_pretrain_dataset(input_patterns,
# parallel. You may want to increase this number if you have a large number of
# CPU cores.
dataset = dataset.interleave(
tf.data.TFRecordDataset, cycle_length=8,
tf.data.TFRecordDataset,
cycle_length=8,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if is_training:
dataset = dataset.shuffle(100)
decode_fn = lambda record: decode_record(record, name_to_features)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
......@@ -136,12 +142,8 @@ def create_pretrain_dataset(input_patterns,
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(1024)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
......@@ -174,14 +176,15 @@ def create_classifier_dataset(file_path,
y = record['label_ids']
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(1024)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
......@@ -224,12 +227,13 @@ def create_squad_dataset(file_path,
x[name] = tensor
return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.repeat()
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1024)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment