Commit 5cdbcac3 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Use tf.compat.v1 version of enable_mixed_precision_graph_rewrite.

The function `tf.train.experimental.enable_mixed_precision_graph_rewrite` will be removed from the TF2 namespace soon, at which point it will only be accessible under tf.compat.v1.

PiperOrigin-RevId: 367046393
parent a8ae1619
......@@ -40,11 +40,11 @@ def configure_optimizer(optimizer,
optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
# tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not
# double up.
optimizer = (
tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
optimizer))
return optimizer
......
......@@ -213,8 +213,8 @@ def define_performance(num_parallel_calls=False,
"When --dtype=fp16, how fp16 should be implemented. This has no "
"impact on correctness. 'keras' uses the "
"tf.keras.mixed_precision API. 'graph_rewrite' uses the "
"tf.train.experimental.enable_mixed_precision_graph_rewrite "
"API."))
"tf.compat.v1.mixed_precision."
"enable_mixed_precision_graph_rewrite API."))
@flags.multi_flags_validator(
["fp16_implementation", "dtype", "loss_scale"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment