Unverified Commit dd5a91d3 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Improve Keras graph performance for ResNet56 (#7241)

* Config threadpool, cuDNN persistent BN, and grappler layout optimizer properly for ResNet56

* Add tweaked tests for Resnet56

* Avoid triggering the last partial batch overhead by explicitly dropping remainder
parent b7221961
......@@ -114,7 +114,8 @@ def input_fn(is_training,
dtype=tf.float32,
datasets_num_private_threads=None,
parse_record_fn=parse_record,
input_context=None):
input_context=None,
drop_remainder=False):
"""Input function which provides batches for train or eval.
Args:
......@@ -127,6 +128,8 @@ def input_fn(is_training,
parse_record_fn: Function to use for parsing the records.
input_context: A `tf.distribute.InputContext` object passed in by
`tf.distribute.Strategy`.
drop_remainder: A boolean indicates whether to drop the remainder of the
batches. If True, the batch dimension will be static.
Returns:
A dataset that can be used for iteration.
......@@ -149,7 +152,8 @@ def input_fn(is_training,
parse_record_fn=parse_record_fn,
num_epochs=num_epochs,
dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads
datasets_num_private_threads=datasets_num_private_threads,
drop_remainder=drop_remainder
)
......
......@@ -99,8 +99,16 @@ def run(flags_obj):
Returns:
Dictionary of training and eval stats.
"""
keras_utils.set_session_config(enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla,
enable_grappler_layout_optimizer=
flags_obj.enable_grappler_layout_optimizer)
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
keras_common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
......@@ -120,6 +128,14 @@ def run(flags_obj):
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs)
if strategy:
# flags_obj.enable_get_next_as_optional controls whether enabling
# get_next_as_optional behavior in DistributedIterator. If true, last
# partial batch can be supported.
strategy.extended.experimental_enable_get_next_as_optional = (
flags_obj.enable_get_next_as_optional
)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
if flags_obj.use_synthetic_data:
......@@ -129,7 +145,8 @@ def run(flags_obj):
width=cifar_main.WIDTH,
num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj))
dtype=flags_core.get_tf_dtype(flags_obj),
drop_remainder=True)
else:
distribution_utils.undo_set_up_synthetic_data()
input_fn = cifar_main.input_fn
......@@ -141,7 +158,11 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype)
dtype=dtype,
# Setting drop_remainder to avoid the partial batch logic in normalization
# layer, which triggers tf.where and leads to extra memory copy of input
# sizes between host and GPU.
drop_remainder=(not flags_obj.enable_get_next_as_optional))
eval_input_dataset = None
if not flags_obj.skip_eval:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment