Unverified Commit 9bdfb04a authored by Toby Boyd, committed by GitHub

V1 optimizer fix (#6350)

* optimizer back to compat.v1

* add doc string to fix lint
parent 0b0dc7f5
@@ -266,8 +266,6 @@ def learning_rate_with_decay(
                      false_fn=lambda: lr)
     return lr
 
   def poly_rate_fn(global_step):
     """Handles linear scaling rule, gradual warmup, and LR decay.
@@ -277,10 +275,10 @@ def learning_rate_with_decay(
     decay schedule with power 2.0.
 
     Args:
      global_step: the current global_step
 
    Returns:
      returns the current learning rate
    """
 
    # Learning rate schedule for LARS polynomial schedule
@@ -318,7 +316,6 @@ def learning_rate_with_decay(
   if flags.FLAGS.enable_lars:
     return poly_rate_fn
 
   return learning_rate_fn
@@ -360,6 +357,7 @@ def resnet_model_fn(features, labels, mode, model_class,
       from the loss.
     dtype: the TensorFlow dtype to use for calculations.
     fine_tune: If True only train the dense layers(final layers).
+    label_smoothing: If greater than 0 then smooth the labels.
 
   Returns:
     EstimatorSpec parameterized according to the input params and the
@@ -402,7 +400,7 @@ def resnet_model_fn(features, labels, mode, model_class,
         logits=logits, onehot_labels=one_hot_labels,
         label_smoothing=label_smoothing)
   else:
-    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+    cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
         logits=logits, labels=labels)
 
   # Create a tensor named cross_entropy for logging purposes.
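The hunk around line 400 is the core of the fix: TF 1.x symbols that TensorFlow 2.x no longer exposes at the top level are spelled through `tf.compat.v1`, in line with the commit message "optimizer back to compat.v1". The sketch below is not part of the commit; it is a minimal, standalone illustration of that pattern. The concrete logits/labels values, the variable name, and the choice of `MomentumOptimizer` are assumptions for the example, not code from the diff.

```python
import tensorflow as tf

# Minimal sketch (not from this commit) of the compat.v1 pattern the fix
# restores: v1-style losses and optimizers are reached via tf.compat.v1 so
# the code keeps working when the TF 1.x top-level aliases are gone in TF 2.x.
tf.compat.v1.disable_eager_execution()  # v1 losses expect graph mode

# Illustrative placeholder data; a variable gives the optimizer something to train.
logits = tf.Variable([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
labels = tf.constant([0, 1], dtype=tf.int64)

# The v1 sparse softmax loss, spelled via tf.compat.v1.losses as in the hunk.
cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
    logits=logits, labels=labels)

# A v1 optimizer, likewise reached through tf.compat.v1.train
# ("optimizer back to compat.v1"); MomentumOptimizer is an illustrative choice.
optimizer = tf.compat.v1.train.MomentumOptimizer(
    learning_rate=0.1, momentum=0.9)
train_op = optimizer.minimize(
    cross_entropy,
    global_step=tf.compat.v1.train.get_or_create_global_step())
```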