Unverified Commit f6c2d9f8 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Support pure eager execution in ResNet50 (#6929)

* Support pure eager execution in ResNet50

* Use smaller batch size
parent 15db2195
...@@ -279,6 +279,9 @@ def define_keras_flags(): ...@@ -279,6 +279,9 @@ def define_keras_flags():
"""Define flags for Keras models.""" """Define flags for Keras models."""
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?') flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
flags.DEFINE_boolean(
name='run_eagerly', default=False,
help='Run the model op by op without building a model function.')
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?') flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_boolean(name='use_trivial_model', default=False, flags.DEFINE_boolean(name='use_trivial_model', default=False,
help='Whether to use a trivial Keras model.') help='Whether to use a trivial Keras model.')
......
...@@ -192,6 +192,19 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -192,6 +192,19 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.batch_size = 128 FLAGS.batch_size = 128
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
"""Test Keras model with 1 GPU, no distribution strategy, run eagerly."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_no_dist_strat_run_eagerly')
FLAGS.batch_size = 64
self._run_and_report_benchmark()
def benchmark_graph_1_gpu_no_dist_strat(self): def benchmark_graph_1_gpu_no_dist_strat(self):
"""Test Keras model in legacy graph mode with 1 GPU, no dist strat.""" """Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
self._setup() self._setup()
......
...@@ -222,6 +222,7 @@ def run(flags_obj): ...@@ -222,6 +222,7 @@ def run(flags_obj):
optimizer=optimizer, optimizer=optimizer,
metrics=(['sparse_categorical_accuracy'] metrics=(['sparse_categorical_accuracy']
if flags_obj.report_accuracy_metrics else None), if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly,
cloning=flags_obj.clone_model_in_keras_dist_strat) cloning=flags_obj.clone_model_in_keras_dist_strat)
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment