Commit 920defcc authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Use tf.compat.v1 version of enable_mixed_precision_graph_rewrite.

The function `tf.train.experimental.enable_mixed_precision_graph_rewrite` will be removed from the TF2 namespace soon, at which point it will only be accessible under tf.compat.v1.

PiperOrigin-RevId: 367046393
parent 17814043
...@@ -203,8 +203,9 @@ def run(flags_obj): ...@@ -203,8 +203,9 @@ def run(flags_obj):
# which will ensure tf.compat.v2.keras.mixed_precision and # which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up. # up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite( optimizer = (
optimizer) tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
optimizer))
# TODO(hongkuny): Remove trivial model usage and move it to benchmark. # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if flags_obj.use_trivial_model: if flags_obj.use_trivial_model:
......
...@@ -75,7 +75,8 @@ def _run_benchmark(): ...@@ -75,7 +75,8 @@ def _run_benchmark():
with tf.distribute.MirroredStrategy().scope(): with tf.distribute.MirroredStrategy().scope():
model = tf.keras.applications.ResNet50(weights=None) model = tf.keras.applications.ResNet50(weights=None)
model.compile( model.compile(
optimizer=tf.train.experimental.enable_mixed_precision_graph_rewrite( optimizer=tf.compat.v1.mixed_precision
.enable_mixed_precision_graph_rewrite(
tf.keras.optimizers.Adam(), loss_scale="dynamic"), tf.keras.optimizers.Adam(), loss_scale="dynamic"),
loss="sparse_categorical_crossentropy", loss="sparse_categorical_crossentropy",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment