Commit 5a55f69d authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 310104070
parent 44c3e33f
...@@ -41,7 +41,9 @@ def single_file_dataset(input_file, name_to_features): ...@@ -41,7 +41,9 @@ def single_file_dataset(input_file, name_to_features):
# For training, we want a lot of parallel reading and shuffling. # For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter. # For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file) d = tf.data.TFRecordDataset(input_file)
d = d.map(lambda record: decode_record(record, name_to_features)) d = d.map(
lambda record: decode_record(record, name_to_features),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
# When `input_file` is a path to a single file or a list # When `input_file` is a path to a single file or a list
# containing a single path, disable auto sharding so that # containing a single path, disable auto sharding so that
...@@ -107,9 +109,13 @@ def create_pretrain_dataset(input_patterns, ...@@ -107,9 +109,13 @@ def create_pretrain_dataset(input_patterns,
# parallel. You may want to increase this number if you have a large number of # parallel. You may want to increase this number if you have a large number of
# CPU cores. # CPU cores.
dataset = dataset.interleave( dataset = dataset.interleave(
tf.data.TFRecordDataset, cycle_length=8, tf.data.TFRecordDataset,
cycle_length=8,
num_parallel_calls=tf.data.experimental.AUTOTUNE) num_parallel_calls=tf.data.experimental.AUTOTUNE)
if is_training:
dataset = dataset.shuffle(100)
decode_fn = lambda record: decode_record(record, name_to_features) decode_fn = lambda record: decode_record(record, name_to_features)
dataset = dataset.map( dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
...@@ -136,12 +142,8 @@ def create_pretrain_dataset(input_patterns, ...@@ -136,12 +142,8 @@ def create_pretrain_dataset(input_patterns,
dataset = dataset.map( dataset = dataset.map(
_select_data_from_record, _select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE) num_parallel_calls=tf.data.experimental.AUTOTUNE)
if is_training:
dataset = dataset.shuffle(100)
dataset = dataset.batch(batch_size, drop_remainder=is_training) dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(1024) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset return dataset
...@@ -174,14 +176,15 @@ def create_classifier_dataset(file_path, ...@@ -174,14 +176,15 @@ def create_classifier_dataset(file_path,
y = record['label_ids'] y = record['label_ids']
return (x, y) return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training: if is_training:
dataset = dataset.shuffle(100) dataset = dataset.shuffle(100)
dataset = dataset.repeat() dataset = dataset.repeat()
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=is_training) dataset = dataset.batch(batch_size, drop_remainder=is_training)
dataset = dataset.prefetch(1024) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset return dataset
...@@ -224,12 +227,13 @@ def create_squad_dataset(file_path, ...@@ -224,12 +227,13 @@ def create_squad_dataset(file_path,
x[name] = tensor x[name] = tensor
return (x, y) return (x, y)
dataset = dataset.map(_select_data_from_record)
if is_training: if is_training:
dataset = dataset.shuffle(100) dataset = dataset.shuffle(100)
dataset = dataset.repeat() dataset = dataset.repeat()
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1024) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset return dataset
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment