Commit b9e30b11 authored by Shining Sun's avatar Shining Sun
Browse files

bug fix

parent 6b673d03
......@@ -112,6 +112,7 @@ def run_cifar_with_keras(flags_obj):
Raises:
ValueError: If fp16 is passed as it is not currently supported.
"""
print(">>>>>>>>>>>>>>>>>>>>> eager: ", flags_obj.enable_eager)
if flags_obj.enable_eager:
tf.enable_eager_execution()
......@@ -199,13 +200,13 @@ def run_cifar_with_keras(flags_obj):
verbose=1)
print('Test loss:', eval_output[0])
stats = keras_common.analyze_eval_result(eval_output)
stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
return stats
def define_keras_cifar_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(name='enable_eager', default=True, help='Enable eager?')
def main(_):
......
......@@ -138,6 +138,7 @@ def get_dist_strategy():
print('Not using distribution strategies.')
strategy = None
else:
print(">>>>>>>>>>>>>>>>>>strategy!!!!!!! ", FLAGS.num_gpus)
strategy = distribution_utils.get_distribution_strategy(
num_gpus=FLAGS.num_gpus)
......@@ -158,7 +159,7 @@ def get_fit_callbacks(learning_rate_schedule_fn):
return time_callback, tensorboard_callback, lr_callback
def analyze_eval_result(eval_output):
def analyze_fit_and_eval_result(history, eval_output):
stats = {}
stats['accuracy_top_1'] = eval_output[1]
stats['eval_loss'] = eval_output[0]
......
......@@ -178,23 +178,24 @@ def run_imagenet_with_keras(flags_obj):
num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //
flags_obj.batch_size)
model.fit(train_input_dataset,
epochs=flags_obj.train_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=[
time_callback,
lr_callback,
tensorboard_callback
],
validation_steps=num_eval_steps,
validation_data=eval_input_dataset,
verbose=1)
history = model.fit(train_input_dataset,
epochs=flags_obj.train_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=[
time_callback,
lr_callback,
tensorboard_callback
],
validation_steps=num_eval_steps,
validation_data=eval_input_dataset,
verbose=1)
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=1)
print('Test loss:', eval_output[0])
stats = keras_common.analyze_eval_result(eval_output)
stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
return stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment