Commit 59788849 authored by Priya Gupta's avatar Priya Gupta
Browse files

Add option to run eager; make 1 GPU case run without DS

parent b9d2b1bb
......@@ -190,6 +190,9 @@ def run_cifar_with_keras(flags_obj):
Raises:
ValueError: If fp16 is passed as it is not currently supported.
"""
if flags_obj.enable_eager:
tf.enable_eager_execution()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
......@@ -263,6 +266,11 @@ def run_cifar_with_keras(flags_obj):
loss = 'categorical_crossentropy'
accuracy = 'categorical_accuracy'
if flags_obj.num_gpus == 1:
model.compile(loss=loss,
optimizer=opt,
metrics=[accuracy])
else:
model.compile(loss=loss,
optimizer=opt,
metrics=[accuracy],
......@@ -284,6 +292,7 @@ def run_cifar_with_keras(flags_obj):
num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
flags_obj.batch_size)
print("Executing eagerly?:", tf.executing_eagerly())
model.fit(train_input_dataset,
epochs=flags_obj.train_epochs,
steps_per_epoch=steps_per_epoch,
......@@ -300,6 +309,7 @@ def run_cifar_with_keras(flags_obj):
print('Test loss:', eval_output[0])
def main(_):
with logger.benchmark_context(flags.FLAGS):
run_cifar_with_keras(flags.FLAGS)
......@@ -307,4 +317,5 @@ def main(_):
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.DEBUG)
cifar_main.define_cifar_flags()
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
absl_app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment