Commit 40fd5c4f authored by Chen Qian, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 474159392
parent 1aac52a3
@@ -48,7 +48,7 @@ def build_optimizer(
       `ExponentialMovingAverage`.

   Returns:
-    A tf.keras.Optimizer.
+    A tf.keras.optimizers.legacy.Optimizer.

   Raises:
     ValueError if the provided optimizer_name is not supported.
@@ -60,12 +60,12 @@ def build_optimizer(
   if optimizer_name == 'sgd':
     logging.info('Using SGD optimizer')
     nesterov = params.get('nesterov', False)
-    optimizer = tf.keras.optimizers.SGD(
+    optimizer = tf.keras.optimizers.legacy.SGD(
         learning_rate=base_learning_rate, nesterov=nesterov)
   elif optimizer_name == 'momentum':
     logging.info('Using momentum optimizer')
     nesterov = params.get('nesterov', False)
-    optimizer = tf.keras.optimizers.SGD(
+    optimizer = tf.keras.optimizers.legacy.SGD(
         learning_rate=base_learning_rate,
         momentum=params['momentum'],
         nesterov=nesterov)
@@ -74,7 +74,7 @@ def build_optimizer(
     rho = params.get('decay', None) or params.get('rho', 0.9)
     momentum = params.get('momentum', 0.9)
     epsilon = params.get('epsilon', 1e-07)
-    optimizer = tf.keras.optimizers.RMSprop(
+    optimizer = tf.keras.optimizers.legacy.RMSprop(
         learning_rate=base_learning_rate,
         rho=rho,
         momentum=momentum,
@@ -84,7 +84,7 @@ def build_optimizer(
     beta_1 = params.get('beta_1', 0.9)
     beta_2 = params.get('beta_2', 0.999)
     epsilon = params.get('epsilon', 1e-07)
-    optimizer = tf.keras.optimizers.Adam(
+    optimizer = tf.keras.optimizers.legacy.Adam(
         learning_rate=base_learning_rate,
         beta_1=beta_1,
         beta_2=beta_2,
...
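For context, a minimal sketch (not part of this commit) of what the rewritten 'momentum' branch constructs; the params contents and base_learning_rate value are assumptions made up for illustration:

import tensorflow as tf

# Hypothetical inputs for illustration only; in the real factory these come
# from the caller's configuration, not from this commit.
base_learning_rate = 0.01
params = {'momentum': 0.9, 'nesterov': True}

# After this change the factory pins the optimizer to the legacy
# (pre-TF-2.11) implementation, so its behavior and checkpoint format stay
# stable as the default tf.keras.optimizers classes migrate to the new
# Keras optimizer base class.
optimizer = tf.keras.optimizers.legacy.SGD(
    learning_rate=base_learning_rate,
    momentum=params['momentum'],
    nesterov=params.get('nesterov', False))

The legacy classes accept the same constructor arguments as the originals, which is why only the namespace changes in the diff above.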