Commit b9e30b11 authored by Shining Sun's avatar Shining Sun
Browse files

bug fix

parent 6b673d03
...@@ -112,6 +112,7 @@ def run_cifar_with_keras(flags_obj): ...@@ -112,6 +112,7 @@ def run_cifar_with_keras(flags_obj):
Raises: Raises:
ValueError: If fp16 is passed as it is not currently supported. ValueError: If fp16 is passed as it is not currently supported.
""" """
print(">>>>>>>>>>>>>>>>>>>>> eager: ", flags_obj.enable_eager)
if flags_obj.enable_eager: if flags_obj.enable_eager:
tf.enable_eager_execution() tf.enable_eager_execution()
...@@ -199,13 +200,13 @@ def run_cifar_with_keras(flags_obj): ...@@ -199,13 +200,13 @@ def run_cifar_with_keras(flags_obj):
verbose=1) verbose=1)
print('Test loss:', eval_output[0]) print('Test loss:', eval_output[0])
stats = keras_common.analyze_eval_result(eval_output) stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
return stats return stats
def define_keras_cifar_flags(): def define_keras_cifar_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?') flags.DEFINE_boolean(name='enable_eager', default=True, help='Enable eager?')
def main(_): def main(_):
......
...@@ -138,6 +138,7 @@ def get_dist_strategy(): ...@@ -138,6 +138,7 @@ def get_dist_strategy():
print('Not using distribution strategies.') print('Not using distribution strategies.')
strategy = None strategy = None
else: else:
print(">>>>>>>>>>>>>>>>>>strategy!!!!!!! ", FLAGS.num_gpus)
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
num_gpus=FLAGS.num_gpus) num_gpus=FLAGS.num_gpus)
...@@ -158,7 +159,7 @@ def get_fit_callbacks(learning_rate_schedule_fn): ...@@ -158,7 +159,7 @@ def get_fit_callbacks(learning_rate_schedule_fn):
return time_callback, tensorboard_callback, lr_callback return time_callback, tensorboard_callback, lr_callback
def analyze_eval_result(eval_output): def analyze_fit_and_eval_result(history, eval_output):
stats = {} stats = {}
stats['accuracy_top_1'] = eval_output[1] stats['accuracy_top_1'] = eval_output[1]
stats['eval_loss'] = eval_output[0] stats['eval_loss'] = eval_output[0]
......
...@@ -178,23 +178,24 @@ def run_imagenet_with_keras(flags_obj): ...@@ -178,23 +178,24 @@ def run_imagenet_with_keras(flags_obj):
num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] // num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
model.fit(train_input_dataset, history = model.fit(train_input_dataset,
epochs=flags_obj.train_epochs, epochs=flags_obj.train_epochs,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
callbacks=[ callbacks=[
time_callback, time_callback,
lr_callback, lr_callback,
tensorboard_callback tensorboard_callback
], ],
validation_steps=num_eval_steps, validation_steps=num_eval_steps,
validation_data=eval_input_dataset, validation_data=eval_input_dataset,
verbose=1) verbose=1)
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps, steps=num_eval_steps,
verbose=1) verbose=1)
print('Test loss:', eval_output[0]) print('Test loss:', eval_output[0])
stats = keras_common.analyze_eval_result(eval_output) stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
return stats return stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment