Commit 84266860 authored by Shining Sun's avatar Shining Sun
Browse files

set learning rate for eager mode

parent b9e30b11
...@@ -112,7 +112,6 @@ def run_cifar_with_keras(flags_obj): ...@@ -112,7 +112,6 @@ def run_cifar_with_keras(flags_obj):
Raises: Raises:
ValueError: If fp16 is passed as it is not currently supported. ValueError: If fp16 is passed as it is not currently supported.
""" """
print(">>>>>>>>>>>>>>>>>>>>> eager: ", flags_obj.enable_eager)
if flags_obj.enable_eager: if flags_obj.enable_eager:
tf.enable_eager_execution() tf.enable_eager_execution()
...@@ -206,7 +205,7 @@ def run_cifar_with_keras(flags_obj): ...@@ -206,7 +205,7 @@ def run_cifar_with_keras(flags_obj):
def define_keras_cifar_flags(): def define_keras_cifar_flags():
flags.DEFINE_boolean(name='enable_eager', default=True, help='Enable eager?') flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
def main(_): def main(_):
......
...@@ -110,7 +110,8 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback): ...@@ -110,7 +110,8 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
if not isinstance(lr, (float, np.float32, np.float64)): if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be float.') raise ValueError('The output of the "schedule" function should be float.')
if lr != self.prev_lr: if lr != self.prev_lr:
tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr) self.model.optimizer.learning_rate = lr # lr should be a float here
# tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)
self.prev_lr = lr self.prev_lr = lr
tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change ' tf.logging.debug('Epoch %05d Batch %05d: LearningRateBatchScheduler change '
'learning rate to %s.', self.epochs, batch, lr) 'learning rate to %s.', self.epochs, batch, lr)
...@@ -134,11 +135,10 @@ def get_optimizer_loss_and_metrics(): ...@@ -134,11 +135,10 @@ def get_optimizer_loss_and_metrics():
def get_dist_strategy(): def get_dist_strategy():
if FLAGS.num_gpus == 1 and FLAGS.dist_strat_off: if True: # FLAGS.num_gpus == 1 and FLAGS.dist_strat_off:
print('Not using distribution strategies.') print('Not using distribution strategies.')
strategy = None strategy = None
else: else:
print(">>>>>>>>>>>>>>>>>>strategy!!!!!!! ", FLAGS.num_gpus)
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
num_gpus=FLAGS.num_gpus) num_gpus=FLAGS.num_gpus)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment