Commit 920defcc authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Use tf.compat.v1 version of enable_mixed_precision_graph_rewrite.

The function `tf.train.experimental.enable_mixed_precision_graph_rewrite` will be removed from the TF2 namespace soon, at which point it will only be accessible under tf.compat.v1.

PiperOrigin-RevId: 367046393
parent 17814043
......@@ -203,8 +203,9 @@ def run(flags_obj):
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
optimizer = (
tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
optimizer))
# TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if flags_obj.use_trivial_model:
......
......@@ -75,7 +75,8 @@ def _run_benchmark():
with tf.distribute.MirroredStrategy().scope():
model = tf.keras.applications.ResNet50(weights=None)
model.compile(
optimizer=tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer=tf.compat.v1.mixed_precision
.enable_mixed_precision_graph_rewrite(
tf.keras.optimizers.Adam(), loss_scale="dynamic"),
loss="sparse_categorical_crossentropy",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment