Commit 85956b16 authored by Jing Li's avatar Jing Li Committed by A. Unique TensorFlower
Browse files

Update interleave hyperparameters

PiperOrigin-RevId: 265780130
parent 0fa5ff23
...@@ -94,8 +94,12 @@ def create_pretrain_dataset(file_paths, ...@@ -94,8 +94,12 @@ def create_pretrain_dataset(file_paths,
dataset = dataset.shuffle(len(file_paths)) dataset = dataset.shuffle(len(file_paths))
# In parallel, create tf record dataset for each train files. # In parallel, create tf record dataset for each train files.
# cycle_length = 8 means that up to 8 files will be read and deserialized in
# parallel. You may want to increase this number if you have a large number of
# CPU cores.
dataset = dataset.interleave( dataset = dataset.interleave(
tf.data.TFRecordDataset, cycle_length=tf.data.experimental.AUTOTUNE) tf.data.TFRecordDataset, cycle_length=8,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
decode_fn = lambda record: decode_record(record, name_to_features) decode_fn = lambda record: decode_record(record, name_to_features)
dataset = dataset.map( dataset = dataset.map(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment