Unverified Commit dd5a91d3 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Improve Keras graph performance for ResNet56 (#7241)

* Config threadpool, cuDNN persistent BN, and grappler layout optimizer properly for ResNet56

* Add tweaked tests for Resnet56

* Avoid triggering the last partial batch overhead by explicitly dropping remainder
parent b7221961
...@@ -114,7 +114,8 @@ def input_fn(is_training, ...@@ -114,7 +114,8 @@ def input_fn(is_training,
dtype=tf.float32, dtype=tf.float32,
datasets_num_private_threads=None, datasets_num_private_threads=None,
parse_record_fn=parse_record, parse_record_fn=parse_record,
input_context=None): input_context=None,
drop_remainder=False):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -127,6 +128,8 @@ def input_fn(is_training, ...@@ -127,6 +128,8 @@ def input_fn(is_training,
parse_record_fn: Function to use for parsing the records. parse_record_fn: Function to use for parsing the records.
input_context: A `tf.distribute.InputContext` object passed in by input_context: A `tf.distribute.InputContext` object passed in by
`tf.distribute.Strategy`. `tf.distribute.Strategy`.
drop_remainder: A boolean indicates whether to drop the remainder of the
batches. If True, the batch dimension will be static.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -149,7 +152,8 @@ def input_fn(is_training, ...@@ -149,7 +152,8 @@ def input_fn(is_training,
parse_record_fn=parse_record_fn, parse_record_fn=parse_record_fn,
num_epochs=num_epochs, num_epochs=num_epochs,
dtype=dtype, dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads datasets_num_private_threads=datasets_num_private_threads,
drop_remainder=drop_remainder
) )
......
...@@ -99,8 +99,16 @@ def run(flags_obj): ...@@ -99,8 +99,16 @@ def run(flags_obj):
Returns: Returns:
Dictionary of training and eval stats. Dictionary of training and eval stats.
""" """
keras_utils.set_session_config(enable_eager=flags_obj.enable_eager, keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla) enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla,
enable_grappler_layout_optimizer=
flags_obj.enable_grappler_layout_optimizer)
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
keras_common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16': if dtype == 'fp16':
...@@ -120,6 +128,14 @@ def run(flags_obj): ...@@ -120,6 +128,14 @@ def run(flags_obj):
all_reduce_alg=flags_obj.all_reduce_alg, all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs) num_packs=flags_obj.num_packs)
if strategy:
# flags_obj.enable_get_next_as_optional controls whether enabling
# get_next_as_optional behavior in DistributedIterator. If true, last
# partial batch can be supported.
strategy.extended.experimental_enable_get_next_as_optional = (
flags_obj.enable_get_next_as_optional
)
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribution_utils.get_strategy_scope(strategy)
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
...@@ -129,7 +145,8 @@ def run(flags_obj): ...@@ -129,7 +145,8 @@ def run(flags_obj):
width=cifar_main.WIDTH, width=cifar_main.WIDTH,
num_channels=cifar_main.NUM_CHANNELS, num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main.NUM_CLASSES, num_classes=cifar_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj)) dtype=flags_core.get_tf_dtype(flags_obj),
drop_remainder=True)
else: else:
distribution_utils.undo_set_up_synthetic_data() distribution_utils.undo_set_up_synthetic_data()
input_fn = cifar_main.input_fn input_fn = cifar_main.input_fn
...@@ -141,7 +158,11 @@ def run(flags_obj): ...@@ -141,7 +158,11 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras, parse_record_fn=parse_record_keras,
datasets_num_private_threads=flags_obj.datasets_num_private_threads, datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype) dtype=dtype,
# Setting drop_remainder to avoid the partial batch logic in normalization
# layer, which triggers tf.where and leads to extra memory copy of input
# sizes between host and GPU.
drop_remainder=(not flags_obj.enable_get_next_as_optional))
eval_input_dataset = None eval_input_dataset = None
if not flags_obj.skip_eval: if not flags_obj.skip_eval:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment