Commit faf4bbb3 authored by Derek Murray's avatar Derek Murray Committed by Hongkun Yu
Browse files

Update tf.contrib.data to tf.data.experimental. (#7650)

parent cb136c62
......@@ -110,7 +110,7 @@ def process_record_dataset(dataset,
# Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the
# critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
# critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
# allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
......
......@@ -183,7 +183,7 @@ def _batch_examples(dataset, batch_size, max_length):
# lengths as well. Resulting lengths of inputs and targets can differ.
return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
return dataset.apply(tf.contrib.data.group_by_window(
return dataset.apply(tf.data.experimental.group_by_window(
key_func=example_to_bucket_id,
reduce_func=batching_fn,
window_size=None,
......@@ -223,7 +223,7 @@ def _read_and_batch_from_files(
# Read files and interleave results. When training, the order of the examples
# will be non-deterministic.
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
tf.data.experimental.parallel_interleave(
_load_records, sloppy=shuffle, cycle_length=num_parallel_calls))
# Parse each tf.Example into a dictionary
......@@ -235,8 +235,9 @@ def _read_and_batch_from_files(
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
if static_batch:
dataset = dataset.apply(tf.contrib.data.padded_batch_and_drop_remainder(
batch_size // max_length, ([max_length], [max_length])))
dataset = dataset.padded_batch(
batch_size // max_length, ([max_length], [max_length]),
drop_remainder=True)
else:
# Group and batch such that each batch has examples of similar length.
dataset = _batch_examples(dataset, batch_size, max_length)
......@@ -244,7 +245,7 @@ def _read_and_batch_from_files(
dataset = dataset.repeat(repeat)
# Prefetch the next element to improve speed of input pipeline.
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
......
......@@ -318,7 +318,7 @@ class COCOGroundtruthGenerator(object):
cycle_length=32,
sloppy=False))
dataset = dataset.map(self._parse_single_example, num_parallel_calls=64)
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(1, drop_remainder=False)
return dataset
......
......@@ -128,7 +128,7 @@ def process_record_dataset(dataset,
# Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the
# critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
# critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
# allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
......
......@@ -207,13 +207,13 @@ def imagenet_input(split, batch_size, image_size, is_training):
# Read the data from disk in parallel
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
tf.data.experimental.parallel_interleave(
fetch_dataset, cycle_length=4, sloppy=True))
dataset = dataset.shuffle(1024)
# Parse, preprocess, and batch the data in parallel
dataset = dataset.apply(
tf.contrib.data.map_and_batch(
tf.data.experimental.map_and_batch(
lambda value: imagenet_parser(value, image_size, is_training),
batch_size=batch_size,
num_parallel_batches=4,
......@@ -231,7 +231,7 @@ def imagenet_input(split, batch_size, image_size, is_training):
dataset = dataset.map(set_shapes)
# Prefetch overlaps in-feed with training
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
......
......@@ -115,7 +115,7 @@ def tiny_imagenet_input(split, batch_size, image_size, is_training):
dataset = dataset.repeat()
dataset = dataset.apply(
tf.contrib.data.map_and_batch(
tf.data.experimental.map_and_batch(
lambda value: tiny_imagenet_parser(value, image_size, is_training),
batch_size=batch_size,
num_parallel_batches=4,
......@@ -132,7 +132,7 @@ def tiny_imagenet_input(split, batch_size, image_size, is_training):
# Assign static batch size dimension
dataset = dataset.map(set_shapes)
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
......
......@@ -270,6 +270,5 @@ def input_fn(batch_size, deep_speech_dataset, repeat=1):
)
# Prefetch to improve speed of input pipeline.
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
......@@ -49,8 +49,8 @@ def read_dataset(file_read_func, input_files, config):
"""Reads a dataset, and handles repetition and shuffling.
Args:
file_read_func: Function to use in tf.contrib.data.parallel_interleave, to
read every individual file into a tf.data.Dataset.
file_read_func: Function to use in tf.data.experimental.parallel_interleave,
to read every individual file into a tf.data.Dataset.
input_files: A list of file paths to read.
config: A input_reader_builder.InputReader object.
......@@ -79,7 +79,7 @@ def read_dataset(file_read_func, input_files, config):
'still slightly shuffled since `num_readers` > 1.')
filename_dataset = filename_dataset.repeat(config.num_epochs or None)
records_dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
tf.data.experimental.parallel_interleave(
file_read_func,
cycle_length=num_readers,
block_length=config.read_block_length,
......@@ -154,8 +154,7 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
data_map_fn = dataset.map
dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)
if batch_size:
dataset = dataset.apply(
tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
return dataset
......
......@@ -309,8 +309,7 @@ class InputDataset(object):
if self._ensure_constant_batch_size:
# Only take batches of *exactly* size batch_size; then we get a
# statically knowable batch shape.
dataset = dataset.apply(
tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.batch(batch_size, drop_remainder=True)
else:
dataset = dataset.batch(batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment