Commit 51fc02ae authored by Shining Sun's avatar Shining Sun
Browse files

Add skip_eval flag and change to new optimizer

parent 6f881f77
...@@ -179,9 +179,10 @@ def run(flags_obj): ...@@ -179,9 +179,10 @@ def run(flags_obj):
validation_data=eval_input_dataset, validation_data=eval_input_dataset,
verbose=1) verbose=1)
eval_output = model.evaluate(eval_input_dataset, if not flags_obj.skip_eval:
steps=num_eval_steps, eval_output = model.evaluate(eval_input_dataset,
verbose=1) steps=num_eval_steps,
verbose=1)
stats = keras_common.analyze_fit_and_eval_result(history, eval_output) stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
......
...@@ -105,7 +105,7 @@ def get_optimizer(): ...@@ -105,7 +105,7 @@ def get_optimizer():
learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256 learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9) optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
else: else:
optimizer = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9) optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
return optimizer return optimizer
...@@ -138,6 +138,8 @@ def analyze_fit_and_eval_result(history, eval_output): ...@@ -138,6 +138,8 @@ def analyze_fit_and_eval_result(history, eval_output):
def define_keras_flags(): def define_keras_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?') flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_integer( flags.DEFINE_integer(
name="train_steps", default=None, name="train_steps", default=None,
help="The number of steps to run for training") help="The number of steps to run for training")
...@@ -172,9 +172,10 @@ def run_imagenet_with_keras(flags_obj): ...@@ -172,9 +172,10 @@ def run_imagenet_with_keras(flags_obj):
validation_data=eval_input_dataset, validation_data=eval_input_dataset,
verbose=1) verbose=1)
eval_output = model.evaluate(eval_input_dataset, if not flags_obj.skip_eval:
steps=num_eval_steps, eval_output = model.evaluate(eval_input_dataset,
verbose=1) steps=num_eval_steps,
verbose=1)
stats = keras_common.analyze_fit_and_eval_result(history, eval_output) stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment