Commit 59788849 authored by Priya Gupta's avatar Priya Gupta
Browse files

Add option to run eager; make 1 GPU case run without DS

parent b9d2b1bb
...@@ -190,6 +190,9 @@ def run_cifar_with_keras(flags_obj): ...@@ -190,6 +190,9 @@ def run_cifar_with_keras(flags_obj):
Raises: Raises:
ValueError: If fp16 is passed as it is not currently supported. ValueError: If fp16 is passed as it is not currently supported.
""" """
if flags_obj.enable_eager:
tf.enable_eager_execution()
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16': if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default ' raise ValueError('dtype fp16 is not supported in Keras. Use the default '
...@@ -262,11 +265,16 @@ def run_cifar_with_keras(flags_obj): ...@@ -262,11 +265,16 @@ def run_cifar_with_keras(flags_obj):
loss = 'categorical_crossentropy' loss = 'categorical_crossentropy'
accuracy = 'categorical_accuracy' accuracy = 'categorical_accuracy'
model.compile(loss=loss, if flags_obj.num_gpus == 1:
optimizer=opt, model.compile(loss=loss,
metrics=[accuracy], optimizer=opt,
distribute=strategy) metrics=[accuracy])
else:
model.compile(loss=loss,
optimizer=opt,
metrics=[accuracy],
distribute=strategy)
steps_per_epoch = cifar_main._NUM_IMAGES['train'] // flags_obj.batch_size steps_per_epoch = cifar_main._NUM_IMAGES['train'] // flags_obj.batch_size
...@@ -283,7 +291,8 @@ def run_cifar_with_keras(flags_obj): ...@@ -283,7 +291,8 @@ def run_cifar_with_keras(flags_obj):
num_eval_steps = (cifar_main._NUM_IMAGES['validation'] // num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
print("Executing eagerly?:", tf.executing_eagerly())
model.fit(train_input_dataset, model.fit(train_input_dataset,
epochs=flags_obj.train_epochs, epochs=flags_obj.train_epochs,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
...@@ -300,6 +309,7 @@ def run_cifar_with_keras(flags_obj): ...@@ -300,6 +309,7 @@ def run_cifar_with_keras(flags_obj):
print('Test loss:', eval_output[0]) print('Test loss:', eval_output[0])
def main(_): def main(_):
with logger.benchmark_context(flags.FLAGS): with logger.benchmark_context(flags.FLAGS):
run_cifar_with_keras(flags.FLAGS) run_cifar_with_keras(flags.FLAGS)
...@@ -307,4 +317,5 @@ def main(_): ...@@ -307,4 +317,5 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.DEBUG) tf.logging.set_verbosity(tf.logging.DEBUG)
cifar_main.define_cifar_flags() cifar_main.define_cifar_flags()
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
absl_app.run(main) absl_app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment