Commit 35daa566 authored by Hongkun Yu's avatar Hongkun Yu Committed by GitHub
Browse files

Revert "Revert "Update usage of tf.contrib.data to tf.data.experimental" (#7654)"

This reverts commit b4e560dc.
parent b4e560dc
...@@ -110,7 +110,7 @@ def process_record_dataset(dataset, ...@@ -110,7 +110,7 @@ def process_record_dataset(dataset,
# Operations between the final prefetch and the get_next call to the iterator # Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to # will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the # background all of the above processing work and keep it out of the
# critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE # critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
# allows DistributionStrategies to adjust how many batches to fetch based # allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present. # on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
......
...@@ -183,7 +183,7 @@ def _batch_examples(dataset, batch_size, max_length): ...@@ -183,7 +183,7 @@ def _batch_examples(dataset, batch_size, max_length):
# lengths as well. Resulting lengths of inputs and targets can differ. # lengths as well. Resulting lengths of inputs and targets can differ.
return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None])) return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
return dataset.apply(tf.contrib.data.group_by_window( return dataset.apply(tf.data.experimental.group_by_window(
key_func=example_to_bucket_id, key_func=example_to_bucket_id,
reduce_func=batching_fn, reduce_func=batching_fn,
window_size=None, window_size=None,
...@@ -223,7 +223,7 @@ def _read_and_batch_from_files( ...@@ -223,7 +223,7 @@ def _read_and_batch_from_files(
# Read files and interleave results. When training, the order of the examples # Read files and interleave results. When training, the order of the examples
# will be non-deterministic. # will be non-deterministic.
dataset = dataset.apply( dataset = dataset.apply(
tf.contrib.data.parallel_interleave( tf.data.experimental.parallel_interleave(
_load_records, sloppy=shuffle, cycle_length=num_parallel_calls)) _load_records, sloppy=shuffle, cycle_length=num_parallel_calls))
# Parse each tf.Example into a dictionary # Parse each tf.Example into a dictionary
...@@ -235,8 +235,9 @@ def _read_and_batch_from_files( ...@@ -235,8 +235,9 @@ def _read_and_batch_from_files(
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length)) dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
if static_batch: if static_batch:
dataset = dataset.apply(tf.contrib.data.padded_batch_and_drop_remainder( dataset = dataset.padded_batch(
batch_size // max_length, ([max_length], [max_length]))) batch_size // max_length, ([max_length], [max_length]),
drop_remainder=True)
else: else:
# Group and batch such that each batch has examples of similar length. # Group and batch such that each batch has examples of similar length.
dataset = _batch_examples(dataset, batch_size, max_length) dataset = _batch_examples(dataset, batch_size, max_length)
...@@ -244,7 +245,7 @@ def _read_and_batch_from_files( ...@@ -244,7 +245,7 @@ def _read_and_batch_from_files(
dataset = dataset.repeat(repeat) dataset = dataset.repeat(repeat)
# Prefetch the next element to improve speed of input pipeline. # Prefetch the next element to improve speed of input pipeline.
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset return dataset
......
...@@ -318,7 +318,7 @@ class COCOGroundtruthGenerator(object): ...@@ -318,7 +318,7 @@ class COCOGroundtruthGenerator(object):
cycle_length=32, cycle_length=32,
sloppy=False)) sloppy=False))
dataset = dataset.map(self._parse_single_example, num_parallel_calls=64) dataset = dataset.map(self._parse_single_example, num_parallel_calls=64)
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(1, drop_remainder=False) dataset = dataset.batch(1, drop_remainder=False)
return dataset return dataset
......
...@@ -128,7 +128,7 @@ def process_record_dataset(dataset, ...@@ -128,7 +128,7 @@ def process_record_dataset(dataset,
# Operations between the final prefetch and the get_next call to the iterator # Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to # will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the # background all of the above processing work and keep it out of the
# critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE # critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
# allows DistributionStrategies to adjust how many batches to fetch based # allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present. # on how many devices are present.
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment