"docs/source/functions/segment_coo.rst" did not exist on "3875ae693164abd0de5011de6ac19df8af92dd2d"
Unverified commit d44b7283, authored by Toby Boyd, committed by GitHub

Fix graph.rewrite for TF 2.0 and remove num_parallel_batches (#7019)

* Switch the fp16 graph rewrite to tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite.

* Remove num_parallel_batches, which is not used.
parent 097c8051
@@ -113,7 +113,6 @@ def input_fn(is_training,
              num_epochs=1,
              dtype=tf.float32,
              datasets_num_private_threads=None,
-             num_parallel_batches=1,
              parse_record_fn=parse_record,
              input_context=None):
   """Input function which provides batches for train or eval.
@@ -125,7 +124,6 @@ def input_fn(is_training,
     num_epochs: The number of epochs to repeat the dataset.
     dtype: Data type to use for images/features
     datasets_num_private_threads: Number of private threads for tf.data.
-    num_parallel_batches: Number of parallel batches for tf.data.
     parse_record_fn: Function to use for parsing the records.
     input_context: A `tf.distribute.InputContext` object passed in by
       `tf.distribute.Strategy`.
@@ -151,8 +149,7 @@ def input_fn(is_training,
       parse_record_fn=parse_record_fn,
       num_epochs=num_epochs,
       dtype=dtype,
-      datasets_num_private_threads=datasets_num_private_threads,
-      num_parallel_batches=num_parallel_batches
+      datasets_num_private_threads=datasets_num_private_threads
   )
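
The hunks above delete the `num_parallel_batches` plumbing from `input_fn`; nothing in the current pipeline reads it, since per-element parallelism is set on the map step rather than on batching. For context, a minimal sketch of such a map-then-batch tf.data pipeline, where no separate "parallel batches" knob is needed. This is illustrative only: `parse_record_fn`, `batch_size`, and `dtype` mirror the signature in the diff, while the shuffle buffer and the AUTOTUNE choices are assumptions, not values from this model.

```python
import tensorflow as tf


def process_records_sketch(dataset, is_training, batch_size, parse_record_fn,
                           num_epochs=1, dtype=tf.float32,
                           drop_remainder=False):
  """Illustrative tf.data pipeline with no num_parallel_batches argument.

  Parallelism lives on the map step via num_parallel_calls, so there is no
  separate "parallel batches" setting to thread through the input functions.
  """
  if is_training:
    dataset = dataset.shuffle(buffer_size=10000)  # assumed buffer size
  dataset = dataset.repeat(num_epochs)
  # Parse records in parallel; AUTOTUNE picks the degree of parallelism.
  dataset = dataset.map(
      lambda raw: parse_record_fn(raw, is_training, dtype),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
  # Overlap input preprocessing with model execution.
  return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
```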
@@ -165,7 +165,6 @@ def input_fn(is_training,
              num_epochs=1,
              dtype=tf.float32,
              datasets_num_private_threads=None,
-             num_parallel_batches=1,
              parse_record_fn=parse_record,
              input_context=None,
              drop_remainder=False,
@@ -179,7 +178,6 @@ def input_fn(is_training,
     num_epochs: The number of epochs to repeat the dataset.
     dtype: Data type to use for images/features
     datasets_num_private_threads: Number of private threads for tf.data.
-    num_parallel_batches: Number of parallel batches for tf.data.
     parse_record_fn: Function to use for parsing the records.
     input_context: A `tf.distribute.InputContext` object passed in by
       `tf.distribute.Strategy`.
@@ -223,7 +221,6 @@ def input_fn(is_training,
       num_epochs=num_epochs,
       dtype=dtype,
       datasets_num_private_threads=datasets_num_private_threads,
-      num_parallel_batches=num_parallel_batches,
       drop_remainder=drop_remainder,
       tf_data_experimental_slack=tf_data_experimental_slack,
   )
@@ -54,7 +54,6 @@ def process_record_dataset(dataset,
                            num_epochs=1,
                            dtype=tf.float32,
                            datasets_num_private_threads=None,
-                           num_parallel_batches=1,
                            drop_remainder=False,
                            tf_data_experimental_slack=False):
   """Given a Dataset with raw records, return an iterator over the records.
@@ -72,7 +71,6 @@ def process_record_dataset(dataset,
     dtype: Data type to use for images/features.
     datasets_num_private_threads: Number of threads for a private
       threadpool created for all datasets computation.
-    num_parallel_batches: Number of parallel batches for tf.data.
     drop_remainder: A boolean indicates whether to drop the remainder of the
       batches. If True, the batch dimension will be static.
     tf_data_experimental_slack: Whether to enable tf.data's
@@ -462,7 +460,7 @@ def resnet_model_fn(features, labels, mode, model_class,
   fp16_implementation = getattr(flags.FLAGS, 'fp16_implementation', None)
   if fp16_implementation == 'graph_rewrite':
-    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+    optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
         optimizer, loss_scale=loss_scale)
 
   def _dense_grad_filter(gvs):
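
The hunk above is the TF 2.0 part of the fix: the graph-rewrite entry point is reached through `tf.compat.v1`, the variant that works with the `tf.compat.v1.train` optimizers this Estimator-based model builds. A minimal, self-contained sketch of that wrapping under TF 2.x follows; the optimizer hyperparameters and the `'dynamic'` loss scale are assumptions for illustration, not values taken from this model.

```python
import tensorflow as tf

# The graph rewrite applies to graph/session execution, as used by the
# v1 Estimator code this diff touches.
tf.compat.v1.disable_eager_execution()

# Assumed hyperparameters, for illustration only.
optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=0.1,
                                                 momentum=0.9)

# Rewrites eligible ops to run in float16 and wraps the optimizer with loss
# scaling; 'dynamic' is a common default when no explicit loss scale is given.
optimizer = tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer, loss_scale='dynamic')
```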
@@ -539,6 +537,7 @@ def resnet_main(
     shape: list of ints representing the shape of the images used for training.
       This is only used if flags_obj.export_dir is passed.
 
+  Returns:
     Dict of results of the run. Contains the keys `eval_results` and
       `train_hooks`. `eval_results` contains accuracy (top_1) and accuracy_top_5.
       `train_hooks` is a list the instances of hooks used during training.
@@ -628,7 +627,6 @@ def resnet_main(
         num_epochs=num_epochs,
         dtype=flags_core.get_tf_dtype(flags_obj),
         datasets_num_private_threads=flags_obj.datasets_num_private_threads,
-        num_parallel_batches=flags_obj.datasets_num_parallel_batches,
         input_context=input_context)
 
   def input_fn_eval():
@@ -730,7 +728,6 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
   flags_core.define_performance(num_parallel_calls=False,
                                 tf_gpu_thread_mode=True,
                                 datasets_num_private_threads=True,
-                                datasets_num_parallel_batches=True,
                                 dynamic_loss_scale=dynamic_loss_scale,
                                 fp16_implementation=fp16_implementation,
                                 loss_scale=True,