Unverified Commit d8a09064 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Add run_eagerly and end-to-end test. (#7012)

parent 21072090
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import os import os
import time import time
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main from official.resnet import cifar10_main as cifar_main
from official.resnet.keras import keras_benchmark from official.resnet.keras import keras_benchmark
...@@ -80,6 +80,21 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -80,6 +80,21 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True FLAGS.enable_eager = True
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
"""Test keras based model with forced eager."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.data_dir
FLAGS.batch_size = 128
FLAGS.train_epochs = 182
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_no_dist_strat_run_eagerly')
FLAGS.dtype = 'fp32'
FLAGS.enable_eager = True
FLAGS.run_eagerly = True
FLAGS.distribution_strategy = 'off'
self._run_and_report_benchmark()
def benchmark_2_gpu(self): def benchmark_2_gpu(self):
"""Test keras based model with eager and distribution strategies.""" """Test keras based model with eager and distribution strategies."""
self._setup() self._setup()
......
...@@ -159,6 +159,7 @@ def run(flags_obj): ...@@ -159,6 +159,7 @@ def run(flags_obj):
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
run_eagerly=flags_obj.run_eagerly,
metrics=['categorical_accuracy']) metrics=['categorical_accuracy'])
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment