Commit c069fab0 authored by Chen Qian's avatar Chen Qian Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 471682920
parent b766f6f4
...@@ -106,10 +106,14 @@ class PiecewiseConstantDecayWithWarmup( ...@@ -106,10 +106,14 @@ class PiecewiseConstantDecayWithWarmup(
} }
def get_optimizer(learning_rate=0.1): def get_optimizer(learning_rate=0.1, use_legacy_optimizer=True):
"""Returns optimizer to use.""" """Returns optimizer to use."""
# The learning_rate is overwritten at the beginning of each step by callback. # The learning_rate is overwritten at the beginning of each step by callback.
return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9) if use_legacy_optimizer:
return tf.keras.optimizers.legacy.SGD(
learning_rate=learning_rate, momentum=0.9)
else:
return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
def get_callbacks(pruning_method=None, def get_callbacks(pruning_method=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment