"torchvision/csrc/vscode:/vscode.git/clone" did not exist on "481ef519db4946826f3bc6dd6f28b3c4d2e4d402"
Commit 2488f2c7 authored by Hongkun Yu, committed by A. Unique TensorFlower

[XLNET] Fix tf.data bad usages.

PiperOrigin-RevId: 314177321
parent 440e7851
@@ -93,26 +93,18 @@ def file_based_input_fn_builder(input_file, name_to_features, batch_size,
       # file level shuffle
       d = d.shuffle(len(input_file)).repeat()
-      d = d.apply(
-          tf.data.experimental.parallel_interleave(
-              tf.data.TFRecordDataset,
-              sloppy=is_training,
-              cycle_length=cycle_length))
+      d = d.interleave(
+          tf.data.TFRecordDataset,
+          sloppy=is_training,
+          cycle_length=cycle_length)
       if is_training:
         # sample level shuffle
         d = d.shuffle(buffer_size=2048)
-    # TODO(b/138223458): Hard-code drop_remainder=True to get around the bug
-    # that under TPU strategy, setting drop_remainder=False in
-    # tf.data.Dataset.batch() while data_size can be divided by global
-    # batch_size will trigger dynamic_dimension related TPU compilation error.
-    d = d.apply(
-        tf.data.experimental.map_and_batch(
-            lambda record: _decode_record(record, name_to_features),
-            batch_size=batch_size,
-            num_parallel_batches=num_threads,
-            drop_remainder=True))
+    d = d.map(
+        lambda record: _decode_record(record, name_to_features),
+        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    d = d.batch(batch_size, drop_remainder=is_training)
     # When `input_file` is a path to a single file or a list
     # containing a single path, disable auto sharding so that
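
Note: for readers following the hunk above, here is a minimal, self-contained sketch of the tf.data pattern the new code adopts: tf.data.Dataset.interleave for parallel file reads in place of the deprecated tf.data.experimental.parallel_interleave, and a separate Dataset.map plus Dataset.batch in place of the deprecated map_and_batch fusion. The build_dataset helper, file list, feature spec, and the use of num_parallel_calls/deterministic in interleave (rather than the sloppy flag shown in the diff) are illustrative assumptions, not part of this commit.

import tensorflow as tf

def build_dataset(input_files, batch_size, is_training, seq_length=128):
  """Builds a TFRecord pipeline in the style of the updated code (sketch)."""
  # Hypothetical feature spec; the real spec is supplied by the caller
  # through the name_to_features argument.
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
  }

  def _decode_record(record):
    # Parse a single serialized tf.train.Example.
    return tf.io.parse_single_example(record, name_to_features)

  d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
  # File-level shuffle and repeat, as in the hunk's surrounding context.
  d = d.shuffle(len(input_files)).repeat()
  # Parallel file reads via the core interleave API (assumption: the
  # num_parallel_calls/deterministic arguments stand in for sloppy).
  d = d.interleave(
      tf.data.TFRecordDataset,
      cycle_length=min(8, len(input_files)),
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
      deterministic=not is_training)
  if is_training:
    # Sample-level shuffle.
    d = d.shuffle(buffer_size=2048)
  # Plain map + batch instead of tf.data.experimental.map_and_batch.
  d = d.map(_decode_record,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
  d = d.batch(batch_size, drop_remainder=is_training)
  return d.prefetch(tf.data.experimental.AUTOTUNE)

A call such as build_dataset(["train-0.tfrecord"], batch_size=32, is_training=True) (hypothetical file name) would yield batched, parsed examples ready for the model input function.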
@@ -737,11 +729,7 @@ def parse_files_to_dataset(parser,
     logging.info("Perform sample-level shuffle with size %d", buffer_size)
     dataset = dataset.shuffle(buffer_size=buffer_size)
-  # (zihang): since we are doing online preprocessing, the parsed result of
-  # the same input at each time will be different. Thus, cache processed data
-  # is not helpful. It will use a lot of memory and lead to contrainer OOM.
-  # So, change to cache non-parsed raw data instead.
-  dataset = dataset.cache().map(parser).repeat()
+  dataset = dataset.cache().repeat().map(parser)
   dataset = dataset.batch(bsz_per_core, drop_remainder=True)
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
...
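
Note: the second hunk keeps the raw, un-parsed records in the cache and moves the parser after repeat, so online preprocessing still produces fresh random results on every epoch while the cached raw data is read only once. A minimal sketch of that ordering, assuming a toy random-masking parser and toy tensor shapes in place of the real parser passed in by the caller:

import tensorflow as tf

def random_mask_parser(record):
  # Online preprocessing: a fresh random mask is drawn on every pass,
  # so the *parsed* output must not be cached.
  mask = tf.random.uniform(tf.shape(record)) < 0.15
  return tf.where(mask, tf.zeros_like(record), record)

# Toy raw dataset standing in for the un-parsed records read from disk.
raw = tf.data.Dataset.from_tensor_slices(tf.random.uniform([1000, 64]))

# Cache the raw records once, then repeat and re-parse them, so each
# epoch sees new random masks without re-reading or re-caching the data.
dataset = raw.cache().repeat().map(
    random_mask_parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(32, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)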