Commit faf4bbb3 authored by Derek Murray, committed by Hongkun Yu

Update tf.contrib.data to tf.data.experimental. (#7650)

parent cb136c62
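
The changes below are a mechanical rename of tf.contrib.data symbols to their tf.data.experimental counterparts, plus a few spots where a contrib transformation is replaced by the drop_remainder argument on the core Dataset methods. A rough standalone sketch of the overall pattern (the file glob is a placeholder, not code from this repository):

import tensorflow as tf

# Hypothetical input pipeline showing the before/after of the rename.
files = tf.data.Dataset.list_files("/tmp/data/*.tfrecord")

# Old style:
#   dataset = files.apply(tf.contrib.data.parallel_interleave(
#       tf.data.TFRecordDataset, cycle_length=4))
#   dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

# New style, as used throughout this commit:
dataset = files.apply(
    tf.data.experimental.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=4))
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)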
@@ -110,7 +110,7 @@ def process_record_dataset(dataset,
   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
   # background all of the above processing work and keep it out of the
-  # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
+  # critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
   # allows DistributionStrategies to adjust how many batches to fetch based
   # on how many devices are present.
   dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
...
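
A minimal, self-contained sketch of the AUTOTUNE prefetch used here (toy dataset, not the record pipeline above); the constant simply tells the tf.data runtime, and any DistributionStrategy driving it, to pick the buffer size dynamically:

import tensorflow as tf

# Toy pipeline: let tf.data tune the prefetch buffer instead of hard-coding it.
dataset = tf.data.Dataset.range(1000).batch(32)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)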
@@ -183,7 +183,7 @@ def _batch_examples(dataset, batch_size, max_length):
     # lengths as well. Resulting lengths of inputs and targets can differ.
     return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
-  return dataset.apply(tf.contrib.data.group_by_window(
+  return dataset.apply(tf.data.experimental.group_by_window(
       key_func=example_to_bucket_id,
       reduce_func=batching_fn,
       window_size=None,
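
For readers unfamiliar with the renamed transformation, a standalone toy example of tf.data.experimental.group_by_window (parity buckets rather than the sequence-length buckets used above):

import tensorflow as tf

# Group integers by parity and emit batches of 4 per group.
dataset = tf.data.Dataset.range(20)
dataset = dataset.apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,  # bucket id must be an int64 scalar
        reduce_func=lambda key, window: window.batch(4),
        window_size=4))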
@@ -223,7 +223,7 @@ def _read_and_batch_from_files(
   # Read files and interleave results. When training, the order of the examples
   # will be non-deterministic.
   dataset = dataset.apply(
-      tf.contrib.data.parallel_interleave(
+      tf.data.experimental.parallel_interleave(
           _load_records, sloppy=shuffle, cycle_length=num_parallel_calls))
   # Parse each tf.Example into a dictionary
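
A hedged sketch of the renamed interleave transformation on its own (placeholder file pattern; sloppy=True mirrors the sloppy=shuffle usage above and trades determinism for throughput):

import tensorflow as tf

# Read many TFRecord files in parallel, interleaving their records.
filenames = tf.data.Dataset.list_files("/tmp/records/*.tfrecord", shuffle=True)
records = filenames.apply(
    tf.data.experimental.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=8, sloppy=True))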
@@ -235,8 +235,9 @@ def _read_and_batch_from_files(
   dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
   if static_batch:
-    dataset = dataset.apply(tf.contrib.data.padded_batch_and_drop_remainder(
-        batch_size // max_length, ([max_length], [max_length])))
+    dataset = dataset.padded_batch(
+        batch_size // max_length, ([max_length], [max_length]),
+        drop_remainder=True)
   else:
     # Group and batch such that each batch has examples of similar length.
     dataset = _batch_examples(dataset, batch_size, max_length)
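
The contrib-only padded_batch_and_drop_remainder helper is gone; drop_remainder is now an argument of Dataset.padded_batch itself. A toy illustration of the replacement call (synthetic variable-length sequences, not the Transformer shapes above):

import tensorflow as tf

# Pad variable-length sequences within each batch and drop the final short
# batch so every batch has a static leading dimension of 4.
dataset = tf.data.Dataset.from_generator(
    lambda: ([1] * n for n in range(1, 10)),
    output_types=tf.int32, output_shapes=[None])
dataset = dataset.padded_batch(4, padded_shapes=[None], drop_remainder=True)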
@@ -244,7 +245,7 @@ def _read_and_batch_from_files(
   dataset = dataset.repeat(repeat)
   # Prefetch the next element to improve speed of input pipeline.
-  dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
   return dataset
...
@@ -318,7 +318,7 @@ class COCOGroundtruthGenerator(object):
             cycle_length=32,
             sloppy=False))
     dataset = dataset.map(self._parse_single_example, num_parallel_calls=64)
-    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
+    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
     dataset = dataset.batch(1, drop_remainder=False)
     return dataset
...
@@ -128,7 +128,7 @@ def process_record_dataset(dataset,
   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
   # background all of the above processing work and keep it out of the
-  # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
+  # critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
   # allows DistributionStrategies to adjust how many batches to fetch based
   # on how many devices are present.
   dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
...
@@ -207,13 +207,13 @@ def imagenet_input(split, batch_size, image_size, is_training):
   # Read the data from disk in parallel
   dataset = dataset.apply(
-      tf.contrib.data.parallel_interleave(
+      tf.data.experimental.parallel_interleave(
          fetch_dataset, cycle_length=4, sloppy=True))
   dataset = dataset.shuffle(1024)
   # Parse, preprocess, and batch the data in parallel
   dataset = dataset.apply(
-      tf.contrib.data.map_and_batch(
+      tf.data.experimental.map_and_batch(
          lambda value: imagenet_parser(value, image_size, is_training),
          batch_size=batch_size,
          num_parallel_batches=4,
@@ -231,7 +231,7 @@ def imagenet_input(split, batch_size, image_size, is_training):
   dataset = dataset.map(set_shapes)
   # Prefetch overlaps in-feed with training
-  dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
...
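
A small standalone sketch of the renamed fused transformation (toy map function, not the ImageNet parser above):

import tensorflow as tf

# Fused map-and-batch: apply the map function and batch in a single op.
def _double(value):
  return tf.cast(value, tf.float32) * 2.0

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.map_and_batch(
        _double, batch_size=16, num_parallel_batches=4, drop_remainder=True))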
@@ -115,7 +115,7 @@ def tiny_imagenet_input(split, batch_size, image_size, is_training):
   dataset = dataset.repeat()
   dataset = dataset.apply(
-      tf.contrib.data.map_and_batch(
+      tf.data.experimental.map_and_batch(
          lambda value: tiny_imagenet_parser(value, image_size, is_training),
          batch_size=batch_size,
          num_parallel_batches=4,
@@ -132,7 +132,7 @@ def tiny_imagenet_input(split, batch_size, image_size, is_training):
   # Assign static batch size dimension
   dataset = dataset.map(set_shapes)
-  dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
...
@@ -270,6 +270,5 @@ def input_fn(batch_size, deep_speech_dataset, repeat=1):
   )
   # Prefetch to improve speed of input pipeline.
-  dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
+  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
   return dataset
@@ -49,8 +49,8 @@ def read_dataset(file_read_func, input_files, config):
   """Reads a dataset, and handles repetition and shuffling.
   Args:
-    file_read_func: Function to use in tf.contrib.data.parallel_interleave, to
-      read every individual file into a tf.data.Dataset.
+    file_read_func: Function to use in tf.data.experimental.parallel_interleave,
+      to read every individual file into a tf.data.Dataset.
     input_files: A list of file paths to read.
     config: A input_reader_builder.InputReader object.
@@ -79,7 +79,7 @@ def read_dataset(file_read_func, input_files, config):
                      'still slightly shuffled since `num_readers` > 1.')
     filename_dataset = filename_dataset.repeat(config.num_epochs or None)
   records_dataset = filename_dataset.apply(
-      tf.contrib.data.parallel_interleave(
+      tf.data.experimental.parallel_interleave(
          file_read_func,
          cycle_length=num_readers,
          block_length=config.read_block_length,
@@ -154,8 +154,7 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
       data_map_fn = dataset.map
     dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)
   if batch_size:
-    dataset = dataset.apply(
-        tf.contrib.data.batch_and_drop_remainder(batch_size))
+    dataset = dataset.batch(batch_size, drop_remainder=True)
   dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
   return dataset
...
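
This hunk and the next replace the same contrib helper, batch_and_drop_remainder, with the drop_remainder flag on Dataset.batch; a minimal sketch of the new call:

import tensorflow as tf

# Dropping the final partial batch gives a statically known batch size of 4:
# batches [0..3] and [4..7] are produced, the trailing pair is discarded.
dataset = tf.data.Dataset.range(10).batch(4, drop_remainder=True)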
@@ -309,8 +309,7 @@ class InputDataset(object):
     if self._ensure_constant_batch_size:
       # Only take batches of *exactly* size batch_size; then we get a
       # statically knowable batch shape.
-      dataset = dataset.apply(
-          tf.contrib.data.batch_and_drop_remainder(batch_size))
+      dataset = dataset.batch(batch_size, drop_remainder=True)
     else:
       dataset = dataset.batch(batch_size)
...