"vscode:/vscode.git/clone" did not exist on "0b5cbdcfe07b06b1c539ee4e35102f53ad0373cb"
Commit 2488f2c7 authored by Hongkun Yu, committed by A. Unique TensorFlower

[XLNET] Fix tf.data bad usages.

PiperOrigin-RevId: 314177321
parent 440e7851
@@ -93,26 +93,18 @@ def file_based_input_fn_builder(input_file, name_to_features, batch_size,
     # file level shuffle
     d = d.shuffle(len(input_file)).repeat()
-    d = d.apply(
-        tf.data.experimental.parallel_interleave(
-            tf.data.TFRecordDataset,
-            sloppy=is_training,
-            cycle_length=cycle_length))
+    d = d.interleave(
+        tf.data.TFRecordDataset,
+        cycle_length=cycle_length,
+        num_parallel_calls=tf.data.experimental.AUTOTUNE,
+        deterministic=not is_training)
   if is_training:
     # sample level shuffle
     d = d.shuffle(buffer_size=2048)
-  # TODO(b/138223458): Hard-code drop_remainder=True to get around the bug
-  # that under TPU strategy, setting drop_remainder=False in
-  # tf.data.Dataset.batch() when data_size is divisible by the global
-  # batch_size triggers a dynamic_dimension-related TPU compilation error.
-  d = d.apply(
-      tf.data.experimental.map_and_batch(
-          lambda record: _decode_record(record, name_to_features),
-          batch_size=batch_size,
-          num_parallel_batches=num_threads,
-          drop_remainder=True))
+  d = d.map(
+      lambda record: _decode_record(record, name_to_features),
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  d = d.batch(batch_size, drop_remainder=is_training)
   # When `input_file` is a path to a single file or a list
   # containing a single path, disable auto sharding so that
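
Taken together, this hunk moves the pipeline off two deprecated tf.data.experimental transforms: parallel_interleave becomes the core Dataset.interleave (with num_parallel_calls providing the parallelism and deterministic=False standing in for sloppy=True), and map_and_batch becomes a plain map followed by batch, which tf.data now fuses on its own. Below is a minimal runnable sketch of the resulting pipeline; the file list, feature spec, and constants are hypothetical stand-ins rather than values from this commit.

import tensorflow as tf

# Hypothetical inputs; the real values come from file_based_input_fn_builder.
input_files = ["train-00000.tfrecord", "train-00001.tfrecord"]
name_to_features = {"input_ids": tf.io.FixedLenFeature([128], tf.int64)}
batch_size, cycle_length, is_training = 32, 2, True

def _decode_record(record, features):
  # Parse one serialized tf.Example into a dict of tensors.
  return tf.io.parse_single_example(record, features)

d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
if is_training:
  # File-level shuffle, then repeat over the file list indefinitely.
  d = d.shuffle(len(input_files)).repeat()
# Core-API replacement for tf.data.experimental.parallel_interleave:
# parallelism via num_parallel_calls; deterministic=False ~ sloppy=True.
d = d.interleave(
    tf.data.TFRecordDataset,
    cycle_length=cycle_length,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    deterministic=not is_training)
if is_training:
  # Sample-level shuffle.
  d = d.shuffle(buffer_size=2048)
# Core-API replacement for tf.data.experimental.map_and_batch.
d = d.map(lambda record: _decode_record(record, name_to_features),
          num_parallel_calls=tf.data.experimental.AUTOTUNE)
d = d.batch(batch_size, drop_remainder=is_training)
d = d.prefetch(tf.data.experimental.AUTOTUNE)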
@@ -737,11 +729,7 @@ def parse_files_to_dataset(parser,
logging.info("Perform sample-level shuffle with size %d", buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
# (zihang): since we are doing online preprocessing, the parsed result of
# the same input at each time will be different. Thus, cache processed data
# is not helpful. It will use a lot of memory and lead to contrainer OOM.
# So, change to cache non-parsed raw data instead.
dataset = dataset.cache().map(parser).repeat()
dataset = dataset.cache().repeat().map(parser)
dataset = dataset.batch(bsz_per_core, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
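
The reordering in this hunk is subtler: cache() still sits directly on the raw record stream (so, per the removed comment, only unparsed data is held in memory), but repeat() now precedes map(parser), so the parallel parser runs as one continuous stream across epoch boundaries and re-randomizes its online preprocessing on every pass. A small self-contained sketch of the pattern, with a hypothetical randomized parser standing in for XLNet's online masking:

import tensorflow as tf

def parser(record):
  # Online preprocessing: the output for the same record differs per visit,
  # which is exactly why parsed results must not be cached.
  noise = tf.random.uniform([], dtype=tf.float32)
  return tf.cast(record, tf.float32) + noise

raw = tf.data.Dataset.range(4)  # hypothetical stand-in for raw TFRecords

dataset = (raw
           .cache()    # cache raw records only
           .repeat()   # then repeat, so map runs across epoch boundaries
           .map(parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(2, drop_remainder=True)
           .prefetch(tf.data.experimental.AUTOTUNE))

for batch in dataset.take(4):  # two "epochs" worth of batches
  print(batch.numpy())         # same records, fresh randomness each pass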