Unverified Commit 9bdfb04a authored by Toby Boyd, committed by GitHub

V1 optimizer fix (#6350)

* Move the optimizer back to tf.compat.v1

* Add a docstring to fix lint
parent 0b0dc7f5
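For context: in TF 2.x, the 1.x optimizer symbols this commit reverts to live under `tf.compat.v1`. A minimal sketch of the pattern the title refers to (the learning-rate and momentum values are illustrative, not taken from this commit):

```python
import tensorflow as tf

# TF 1.x-style optimizer reached through the compat.v1 namespace.
# Hyperparameter values here are made up for illustration.
optimizer = tf.compat.v1.train.MomentumOptimizer(
    learning_rate=0.1, momentum=0.9)
```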
@@ -266,8 +266,6 @@ def learning_rate_with_decay(
false_fn=lambda: lr)
return lr
def poly_rate_fn(global_step):
"""Handles linear scaling rule, gradual warmup, and LR decay.
@@ -277,10 +275,10 @@ def learning_rate_with_decay(
decay schedule with power 2.0.
Args:
- global_step: the current global_step
+ global_step: the current global_step
Returns:
- returns the current learning rate
+ returns the current learning rate
"""
# Learning rate schedule for LARS polynomial schedule
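The docstring above describes a linear-warmup plus power-2.0 polynomial decay (the LARS schedule). A rough standalone sketch of that shape, assuming `tf.compat.v1.train.polynomial_decay` and made-up step counts:

```python
import tensorflow as tf

def poly_rate_sketch(global_step, base_lr=0.1,
                     warmup_steps=1000, decay_steps=10000):
  """Linear warmup followed by power-2.0 polynomial decay (illustrative)."""
  step = tf.cast(global_step, tf.float32)
  warmup_lr = base_lr * step / warmup_steps
  decay_lr = tf.compat.v1.train.polynomial_decay(
      base_lr, global_step - warmup_steps, decay_steps, power=2.0)
  # Both branches are built; the warmup branch is returned early on.
  return tf.cond(step < warmup_steps, lambda: warmup_lr, lambda: decay_lr)
```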
@@ -318,7 +316,6 @@ def learning_rate_with_decay(
if flags.FLAGS.enable_lars:
return poly_rate_fn
return learning_rate_fn
@@ -360,6 +357,7 @@ def resnet_model_fn(features, labels, mode, model_class,
from the loss.
dtype: the TensorFlow dtype to use for calculations.
fine_tune: If True only train the dense layers(final layers).
+ label_smoothing: If greater than 0 then smooth the labels.
Returns:
EstimatorSpec parameterized according to the input params and the
@@ -402,7 +400,7 @@ def resnet_model_fn(features, labels, mode, model_class,
logits=logits, onehot_labels=one_hot_labels,
label_smoothing=label_smoothing)
else:
- cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+ cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
logits=logits, labels=labels)
# Create a tensor named cross_entropy for logging purposes.
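The final hunk keeps both loss paths on the `tf.compat.v1.losses` namespace. A self-contained sketch of that branch, with hypothetical shapes and a made-up smoothing value:

```python
import tensorflow as tf

# Hypothetical batch of 8 examples over 10 classes, only to exercise both paths.
logits = tf.random.uniform([8, 10])
labels = tf.random.uniform([8], maxval=10, dtype=tf.int32)
label_smoothing = 0.1  # illustrative value

if label_smoothing > 0:
  one_hot_labels = tf.one_hot(labels, depth=10)
  cross_entropy = tf.compat.v1.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=one_hot_labels,
      label_smoothing=label_smoothing)
else:
  cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
      logits=logits, labels=labels)
```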
......