Commit 84266860 authored by Shining Sun

set learning rate for eager mode

parent b9e30b11
@@ -112,7 +112,6 @@ def run_cifar_with_keras(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
   """
-  print(">>>>>>>>>>>>>>>>>>>>> eager: ", flags_obj.enable_eager)
   if flags_obj.enable_eager:
     tf.enable_eager_execution()
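For context: `tf.enable_eager_execution()` is the TF 1.x switch into eager mode and must be called once at program startup, before any graph or op is built. Below is a minimal sketch of the flag-guarded pattern above, assuming the standard absl `app`/`flags` wiring; the `main` body is illustrative, not the repo's code.

```python
from absl import app
from absl import flags
import tensorflow as tf

flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
FLAGS = flags.FLAGS


def main(_):
  if FLAGS.enable_eager:
    # Must run before any TF ops are created, or it raises a ValueError.
    tf.enable_eager_execution()
  print('Executing eagerly:', tf.executing_eagerly())
  # ... build and train the model here ...


if __name__ == '__main__':
  app.run(main)
```

With absl, a boolean flag can also be flipped on the command line (`--enable_eager` / `--noenable_eager`), so the default change to `False` in the next hunk still leaves eager mode one flag away.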
@@ -206,7 +205,7 @@ def run_cifar_with_keras(flags_obj):
 def define_keras_cifar_flags():
-  flags.DEFINE_boolean(name='enable_eager', default=True, help='Enable eager?')
+  flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
 def main(_):
@@ -110,7 +110,8 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:
-      tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)
+      self.model.optimizer.learning_rate = lr  # lr should be a float here
+      # tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)
       self.prev_lr = lr
       tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
                        'learning rate to %s.', self.epochs, batch, lr)
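This is the substantive change of the commit: under eager execution the callback now assigns the Python float directly to the optimizer instead of routing it through `tf.keras.backend.set_value` (which writes to a backend variable). A condensed sketch of a callback using the direct-assignment pattern follows; the class structure around the changed lines is reconstructed for self-containment, not copied from the repo.

```python
import numpy as np
import tensorflow as tf


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Sets the learning rate from schedule(epoch, batch) before each batch."""

  def __init__(self, schedule):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule = schedule
    self.epochs = 0
    self.prev_lr = -1.0

  def on_epoch_begin(self, epoch, logs=None):
    self.epochs = epoch

  def on_batch_begin(self, batch, logs=None):
    lr = self.schedule(self.epochs, batch)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function should be float.')
    if lr != self.prev_lr:
      # Direct assignment works in eager mode; tf.keras optimizers accept
      # a Python float for learning_rate.
      self.model.optimizer.learning_rate = lr
      self.prev_lr = lr
```

A schedule such as `lambda epoch, batch: 0.1 * (0.95 ** epoch)` can then be passed via `model.fit(..., callbacks=[LearningRateBatchScheduler(schedule)])`.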
@@ -134,11 +135,10 @@ def get_optimizer_loss_and_metrics():
 def get_dist_strategy():
-  if FLAGS.num_gpus == 1 and FLAGS.dist_strat_off:
+  if True:  # FLAGS.num_gpus == 1 and FLAGS.dist_strat_off:
     print('Not using distribution strategies.')
     strategy = None
   else:
-    print(">>>>>>>>>>>>>>>>>>strategy!!!!!!! ", FLAGS.num_gpus)
     strategy = distribution_utils.get_distribution_strategy(
         num_gpus=FLAGS.num_gpus)
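For reference, `distribution_utils.get_distribution_strategy` (a helper from the models repo's utilities) picks a distribution strategy based on the GPU count; the hard-coded `if True:` above simply bypasses it while eager mode is being exercised. A rough sketch of the equivalent selection logic using the public `tf.distribute` API, with `num_gpus` and `dist_strat_off` as illustrative parameters and `MirroredStrategy` standing in for whatever the helper actually returns in the multi-GPU case:

```python
import tensorflow as tf


def get_dist_strategy(num_gpus, dist_strat_off=False):
  # Mirrors the original condition: skip distribution strategies when
  # running on a single GPU with strategies explicitly turned off.
  if num_gpus == 1 and dist_strat_off:
    print('Not using distribution strategies.')
    return None
  # With no arguments, MirroredStrategy replicates across all visible GPUs.
  return tf.distribute.MirroredStrategy()
```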