Unverified commit a68f65f8, authored by Toby Boyd and committed by GitHub

NCF XLA and eager tests, with a refactor of the ResNet flags to make this cleaner. (#7067)

* XLA FP32 and first test.

* More XLA benchmarks, FP32.

* Add eager to NCF and refactor ResNet.

* Fix v2_0 calls and more flag refactoring.

* Remove extra flag args.

* 90-epoch default.

* Add return.

* Remove XLA flag not used by Estimator.

* Remove duplicate run_eagerly.

* Fix flag defaults.

* Remove fp16_implementation flag option.

* Remove early stopping on MLPerf test.

* Remove unneeded args.

* Load flags from Keras mains.
parent b578aee9
@@ -35,6 +35,7 @@ from official.recommendation import data_pipeline
 from official.recommendation import data_preprocessing
 from official.utils.flags import core as flags_core
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils

 FLAGS = flags.FLAGS
@@ -152,7 +153,7 @@ def get_distribution_strategy(params):
 def define_ncf_flags():
   """Add flags for running ncf_main."""
   # Add common flags
-  flags_core.define_base(export_dir=False)
+  flags_core.define_base(export_dir=False, run_eagerly=True)
   flags_core.define_performance(
       num_parallel_calls=False,
       inter_op=False,
@@ -160,7 +161,8 @@ def define_ncf_flags():
       synthetic_data=True,
       max_train_steps=False,
       dtype=False,
-      all_reduce_alg=False
+      all_reduce_alg=False,
+      enable_xla=True
   )
   flags_core.define_device(tpu=True)
   flags_core.define_benchmark()
@@ -318,14 +320,13 @@
 def convert_to_softmax_logits(logits):
-  '''Convert the logits returned by the base model to softmax logits.
-  Softmax with the first column of zeros is equivalent to sigmoid.
-  '''
+  """Convert the logits returned by the base model to softmax logits.
+
+  Args:
+    logits: used to create softmax.
+
+  Returns:
+    Softmax with the first column of zeros is equivalent to sigmoid.
+  """
   softmax_logits = tf.concat([logits * 0, logits], axis=1)
   return softmax_logits
-
-def is_tf_v2():
-  """Returns whether it is v2."""
-  from tensorflow.python import tf2 as tf2_internal
-  return tf2_internal.enabled()
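The docstring's equivalence claim is easy to check numerically: softmax over [0, x] is [1 - sigmoid(x), sigmoid(x)], so column 1 of the softmax equals the sigmoid of the original logit. A minimal standalone check (TF 2.x; not part of this change):

    import tensorflow as tf

    logits = tf.constant([[0.5], [-1.2], [3.0]])
    # Prepend a column of zeros, as convert_to_softmax_logits does.
    softmax_logits = tf.concat([logits * 0, logits], axis=1)
    # softmax([0, x]) = [1/(1+e^x), e^x/(1+e^x)], and e^x/(1+e^x) = sigmoid(x).
    positive_prob = tf.nn.softmax(softmax_logits, axis=1)[:, 1]
    tf.debugging.assert_near(positive_prob, tf.sigmoid(logits)[:, 0])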
@@ -132,6 +132,19 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.early_stopping = True
     self._run_and_report_benchmark()

+  def benchmark_1_gpu_no_dist_strat_run_eagerly_early_stop(self):
+    self._setup()
+    FLAGS.distribution_strategy = 'off'
+    FLAGS.early_stopping = True
+    FLAGS.run_eagerly = True
+    self._run_and_report_benchmark()
+
+  def benchmark_xla_1_gpu_early_stop(self):
+    self._setup()
+    FLAGS.early_stopping = True
+    FLAGS.enable_xla = True
+    self._run_and_report_benchmark()
+
   # NCF with custom training loop. Works only in TF 2.0
   def benchmark_1_gpu_ctl(self):
     self._setup()
@@ -145,6 +158,13 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.early_stopping = True
     self._run_and_report_benchmark()

+  def benchmark_xla_1_gpu_ctl_early_stop(self):
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.early_stopping = True
+    FLAGS.enable_xla = True
+    self._run_and_report_benchmark()
+
   def benchmark_2_gpus(self):
     self._setup()
     FLAGS.num_gpus = 2
@@ -156,15 +176,15 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.num_gpus = 2
     self._run_and_report_benchmark()

-  # NCF with custom training loop. Works only in TF 2.0
   def benchmark_2_gpus_ctl(self):
+    """NCF with custom training loop. Works only in TF 2.0."""
     self._setup()
     FLAGS.keras_use_ctl = True
     FLAGS.num_gpus = 2
     self._run_and_report_benchmark()

-  # NCF with custom training loop. Works only in TF 2.0
   def benchmark_2_gpus_ctl_early_stop(self):
+    """NCF with custom training loop. Works only in TF 2.0."""
     self._setup()
     FLAGS.keras_use_ctl = True
     FLAGS.early_stopping = True
@@ -172,11 +192,12 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     self._run_and_report_benchmark()

   def benchmark_1_gpu_ctl_mlperf_like(self):
-    """1-GPU test to compare Google implementation with MLperf0.5.
-    Using similar rules as MLPerf0.5
-    Using Google's convergence hparams as base for 1-GPU test.
-    Fixed the number of epochs to 7, to remove the perf variance.
-    MLPerf submission consistently converges in 7 epochs.
+    """1-GPU test to compare Google implementation with MLPerf 0.5.
+
+    Using similar rules as MLPerf 0.5
+    - Using Google's convergence hparams as base for 1-GPU test.
+    - Fixed the number of epochs to 7, to remove the perf variance.
+    - MLPerf submission consistently converges in 7 epochs.
     """
     self._setup()
     FLAGS.keras_use_ctl = True
@@ -184,17 +205,39 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     self._run_and_report_benchmark()

   def benchmark_1_gpu_mlperf_like(self):
-    """1-GPU MLPerf like test with compile/fit version"""
+    """1-GPU MLPerf like test with compile/fit version."""
+    self._setup()
+    FLAGS.train_epochs = 7
+    self._run_and_report_benchmark()
+
+  def benchmark_1_gpu_no_dist_strat_mlperf_like(self):
+    """1-GPU MLPerf like test with compile/fit version without dist_strat."""
+    self._setup()
+    FLAGS.train_epochs = 7
+    FLAGS.distribution_strategy = 'off'
+    self._run_and_report_benchmark()
+
+  def benchmark_1_gpu_no_dist_strat_run_eagerly_mlperf_like(self):
+    self._setup()
+    FLAGS.train_epochs = 7
+    FLAGS.distribution_strategy = 'off'
+    FLAGS.run_eagerly = True
+    self._run_and_report_benchmark()
+
+  def benchmark_xla_1_gpu_mlperf_like(self):
+    """1-GPU MLPerf like test with compile/fit version w/xla."""
     self._setup()
     FLAGS.train_epochs = 7
+    FLAGS.enable_xla = True
     self._run_and_report_benchmark()

   def benchmark_8_gpu_ctl_mlperf_like(self):
-    """8 GPU test meant to compare Google implementation
-    with MLperf top line submission using the
-    hyper-parameters from the winning MLPerf0.5 submission.
-    Using similar rules as MLPerf0.5
-    Fixed epochs to MLPerf sumbmission's convergnce on 17 epochs
+    """8 GPU test meant to compare Google implementation.
+
+    MLPerf 0.5 top line submission using the
+    - hyper-parameters from the winning MLPerf0.5 submission.
+    - Using similar rules as MLPerf0.5
+    - Fixed epochs to MLPerf submission's convergence on 17 epochs
     """
     self._setup()
     FLAGS.keras_use_ctl = True
...
@@ -87,7 +87,7 @@ def _get_train_and_eval_data(producer, params):
     features[rconst.DUPLICATE_MASK] = fake_dup_mask
     features[rconst.TRAIN_LABEL_KEY] = labels
-    if params["distribute_strategy"] or not ncf_common.is_tf_v2():
+    if params["distribute_strategy"] or not keras_utils.is_v2_0():
       return features
     else:
       # b/134708104
@@ -100,7 +100,7 @@ def _get_train_and_eval_data(producer, params):
   def preprocess_eval_input(features):
     """Pre-process the eval data.

-    This is needed becasue:
+    This is needed because:
     - The label needs to be extended to be used in the loss fn
     - We need the same inputs for training and eval so adding fake inputs
       for VALID_PT_MASK in eval data.
@@ -112,7 +112,7 @@ def _get_train_and_eval_data(producer, params):
     features[rconst.VALID_POINT_MASK] = fake_valid_pt_mask
     features[rconst.TRAIN_LABEL_KEY] = labels
-    if params["distribute_strategy"] or not ncf_common.is_tf_v2():
+    if params["distribute_strategy"] or not keras_utils.is_v2_0():
       return features
     else:
       # b/134708104
@@ -251,6 +251,8 @@ def _get_keras_model(params):
 def run_ncf(_):
   """Run NCF training and eval with Keras."""
+  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
+
   if FLAGS.seed is not None:
     print("Setting tf seed")
     tf.random.set_seed(FLAGS.seed)
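run_ncf now routes XLA through keras_utils.set_session_config. As a rough sketch only, assuming the helper configures a TF 1.x-style session in the usual way (the repo's actual implementation may differ), enabling XLA amounts to setting the graph's global JIT level:

    import tensorflow as tf

    def set_session_config(enable_eager=False, enable_xla=False):
      # Hypothetical sketch of a helper like keras_utils.set_session_config.
      config = tf.compat.v1.ConfigProto()
      if enable_xla:
        # Compile the session's graphs with the XLA JIT.
        config.graph_options.optimizer_options.global_jit_level = (
            tf.compat.v1.OptimizerOptions.ON_2)
      if enable_eager:
        tf.compat.v1.enable_eager_execution(config=config)
      else:
        sess = tf.compat.v1.Session(config=config)
        tf.compat.v1.keras.backend.set_session(sess)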
@@ -272,7 +274,7 @@ def run_ncf(_):
   params["distribute_strategy"] = strategy

   if (params["keras_use_ctl"] and (
-      not ncf_common.is_tf_v2() or strategy is None)):
+      not keras_utils.is_v2_0() or strategy is None)):
     logging.error(
         "Custom training loop only works with tensorflow 2.0 and dist strat.")
     return
@@ -398,7 +400,8 @@ def run_ncf(_):
   else:
     with distribution_utils.get_strategy_scope(strategy):
-      keras_model.compile(optimizer=optimizer)
+      keras_model.compile(optimizer=optimizer,
+                          run_eagerly=FLAGS.run_eagerly)

       history = keras_model.fit(train_input_dataset,
                                 epochs=FLAGS.train_epochs,
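run_eagerly is a standard tf.keras compile argument: it keeps the train step from being compiled into a graph function, so the model executes op by op (easier to debug, slower to run). A standalone sketch, separate from this diff:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    # With run_eagerly=True, print() and breakpoints inside layers and
    # losses behave normally because no tf.function graph is built.
    model.compile(optimizer="sgd", loss="mse", run_eagerly=True)
    model.fit(tf.ones((8, 4)), tf.ones((8, 1)), epochs=1, verbose=0)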
...
@@ -128,10 +128,7 @@ class Resnet50EstimatorAccuracy(EstimatorBenchmark):
       constructor forward compatible in case PerfZero provides more
       named arguments before updating the constructor.
     """
-    flag_methods = [
-        lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True,
-                                                    fp16_implementation=True)
-    ]
+    flag_methods = [imagenet_main.define_imagenet_flags]

     self.data_dir = os.path.join(root_data_dir, IMAGENET_DATA_DIR_NAME)
     super(Resnet50EstimatorAccuracy, self).__init__(
@@ -193,10 +190,7 @@ class Resnet50EstimatorBenchmarkBase(EstimatorBenchmark):
   local_flags = None

   def __init__(self, output_dir=None, default_flags=None):
-    flag_methods = [
-        lambda: imagenet_main.define_imagenet_flags(dynamic_loss_scale=True,
-                                                    fp16_implementation=True)
-    ]
+    flag_methods = [imagenet_main.define_imagenet_flags]

     super(Resnet50EstimatorBenchmarkBase, self).__init__(
         output_dir=output_dir,
...
@@ -348,14 +348,11 @@ def imagenet_model_fn(features, labels, mode, params):
   )

-def define_imagenet_flags(dynamic_loss_scale=False,
-                          fp16_implementation=False,
-                          enable_xla=False):
+def define_imagenet_flags():
   resnet_run_loop.define_resnet_flags(
       resnet_size_choices=['18', '34', '50', '101', '152', '200'],
-      dynamic_loss_scale=dynamic_loss_scale,
-      fp16_implementation=fp16_implementation,
-      enable_xla=enable_xla)
+      dynamic_loss_scale=True,
+      fp16_implementation=True)
   flags.adopt_module_key_flags(resnet_run_loop)
   flags_core.set_defaults(train_epochs=90)
@@ -390,5 +387,5 @@ def main(_):

 if __name__ == '__main__':
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-  define_imagenet_flags(dynamic_loss_scale=True, fp16_implementation=True)
+  define_imagenet_flags()
   absl_app.run(main)
@@ -49,9 +49,7 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
     """
     self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
-    flag_methods = [
-        keras_common.define_keras_flags, cifar_main.define_cifar_flags
-    ]
+    flag_methods = [keras_cifar_main.define_cifar_flags]

     super(Resnet56KerasAccuracy, self).__init__(
         output_dir=output_dir, flag_methods=flag_methods)
@@ -161,9 +159,7 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
   """Short performance tests for ResNet56 via Keras and CIFAR-10."""

   def __init__(self, output_dir=None, default_flags=None):
-    flag_methods = [
-        keras_common.define_keras_flags, cifar_main.define_cifar_flags
-    ]
+    flag_methods = [keras_cifar_main.define_cifar_flags]

     super(Resnet56KerasBenchmarkBase, self).__init__(
         output_dir=output_dir,
...
@@ -99,7 +99,8 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
-  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager)
+  keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
+                                 enable_xla=flags_obj.enable_xla)

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
@@ -202,6 +203,16 @@ def run(flags_obj):
   return stats

+def define_cifar_flags():
+  keras_common.define_keras_flags(dynamic_loss_scale=False)
+
+  flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
+                          model_dir='/tmp/cifar10_model',
+                          train_epochs=182,
+                          epochs_between_evals=10,
+                          batch_size=128)
+
 def main(_):
   with logger.benchmark_context(flags.FLAGS):
     return run(flags.FLAGS)
@@ -209,6 +220,5 @@ def main(_):

 if __name__ == '__main__':
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-  cifar_main.define_cifar_flags()
-  keras_common.define_keras_flags()
+  define_cifar_flags()
   absl_app.run(main)
@@ -35,9 +35,9 @@ class KerasCifarTest(googletest.TestCase):
   """Unit tests for Keras ResNet with Cifar."""

   _extra_flags = [
-      '-batch_size', '4',
-      '-train_steps', '1',
-      '-use_synthetic_data', 'true'
+      "-batch_size", "4",
+      "-train_steps", "1",
+      "-use_synthetic_data", "true"
   ]
   _tempdir = None
@@ -49,12 +49,11 @@ class KerasCifarTest(googletest.TestCase):
   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(KerasCifarTest, cls).setUpClass()
-    cifar10_main.define_cifar_flags()
-    keras_common.define_keras_flags()
+    keras_cifar_main.define_cifar_flags()

   def setUp(self):
     super(KerasCifarTest, self).setUp()
-    cifar10_main.NUM_IMAGES['validation'] = 4
+    cifar10_main.NUM_IMAGES["validation"] = 4

   def tearDown(self):
     super(KerasCifarTest, self).tearDown()
...
@@ -26,6 +26,7 @@ import numpy as np
 from absl import flags
 import tensorflow as tf

+from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils
 # pylint: disable=ungrouped-imports
 from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
@@ -248,13 +249,21 @@ def build_stats(history, eval_output, callbacks):
   return stats

-def define_keras_flags():
+def define_keras_flags(dynamic_loss_scale=True):
   """Define flags for Keras models."""
+  flags_core.define_base(run_eagerly=True)
+  flags_core.define_performance(num_parallel_calls=False,
+                                tf_gpu_thread_mode=True,
+                                datasets_num_private_threads=True,
+                                dynamic_loss_scale=dynamic_loss_scale,
+                                loss_scale=True,
+                                tf_data_experimental_slack=True,
+                                enable_xla=True)
+  flags_core.define_image()
+  flags_core.define_benchmark()
+  flags.adopt_module_key_flags(flags_core)
+
   flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
-  flags.DEFINE_boolean(
-      name='run_eagerly', default=False,
-      help='Run the model op by op without building a model function.')
   flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
   # TODO(b/135607288): Remove this flag once we understand the root cause of
   # slowdown when setting the learning phase in Keras backend.
...
@@ -262,8 +262,8 @@ def run(flags_obj):

 def define_imagenet_keras_flags():
-  imagenet_main.define_imagenet_flags(dynamic_loss_scale=True, enable_xla=True)
   keras_common.define_keras_flags()
+  flags_core.set_defaults(train_epochs=90)

 def main(_):
...
@@ -35,9 +35,9 @@ class KerasImagenetTest(googletest.TestCase):
   """Unit tests for Keras ResNet with ImageNet."""

   _extra_flags = [
-      '-batch_size', '4',
-      '-train_steps', '1',
-      '-use_synthetic_data', 'true'
+      "-batch_size", "4",
+      "-train_steps", "1",
+      "-use_synthetic_data", "true"
   ]
   _tempdir = None
@@ -53,7 +53,7 @@ class KerasImagenetTest(googletest.TestCase):
   def setUp(self):
     super(KerasImagenetTest, self).setUp()
-    imagenet_main.NUM_IMAGES['validation'] = 4
+    imagenet_main.NUM_IMAGES["validation"] = 4

   def tearDown(self):
     super(KerasImagenetTest, self).tearDown()
...
@@ -723,7 +723,7 @@ def resnet_main(

 def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
-                        fp16_implementation=False, enable_xla=False):
+                        fp16_implementation=False):
   """Add flags and validators for ResNet."""
   flags_core.define_base()
   flags_core.define_performance(num_parallel_calls=False,
@@ -732,8 +732,7 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
                                 dynamic_loss_scale=dynamic_loss_scale,
                                 fp16_implementation=fp16_implementation,
                                 loss_scale=True,
-                                tf_data_experimental_slack=True,
-                                enable_xla=enable_xla)
+                                tf_data_experimental_slack=True)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)
...
@@ -28,7 +28,7 @@ from official.utils.logs import hooks_helper
 def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
                 epochs_between_evals=True, stop_threshold=True, batch_size=True,
                 num_gpu=True, hooks=True, export_dir=True,
-                distribution_strategy=True):
+                distribution_strategy=True, run_eagerly=False):
   """Register base flags.

   Args:
@@ -44,6 +44,7 @@ def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
     export_dir: Create a flag to specify where a SavedModel should be exported.
     distribution_strategy: Create a flag to specify which Distribution Strategy
       to use.
+    run_eagerly: Create a flag to specify to run eagerly op by op.
   Returns:
     A list of flags for core.py to marks as key flags.
   """
@@ -106,6 +107,11 @@ def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
             "How many GPUs to use at each worker with the "
             "DistributionStrategies API. The default is 1."))

+  if run_eagerly:
+    flags.DEFINE_boolean(
+        name="run_eagerly", default=False,
+        help="Run the model op by op without building a model function.")
+
   if hooks:
     # Construct a pretty summary of hooks.
     hook_list_str = (
@@ -142,6 +148,7 @@ def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
         "according to the number of GPUs.")
     )
+  return key_flags
...
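Because the flag is registered only when a model opts in, FLAGS.run_eagerly exists solely for mains that pass run_eagerly=True. A minimal, hypothetical opt-in script using the same API:

    from absl import app, flags
    from official.utils.flags import core as flags_core

    FLAGS = flags.FLAGS

    def main(_):
      # FLAGS.run_eagerly is defined because define_base() below opted in.
      print("run eagerly:", FLAGS.run_eagerly)

    if __name__ == "__main__":
      flags_core.define_base(export_dir=False, run_eagerly=True)
      app.run(main)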