Unverified Commit 50dfb31d authored by Haoyu Zhang, committed by GitHub

Add experimental tf.data sleep tuning for better performance (#6634)

* Introduce a short sleep before ds.prefetch in tf.data.

* Further limit dataset threads to reduce CPU contention

* Tune the dataset sleep time

* Rename dataset sleep flag; enable it only for Keras Graph mode
parent 0d76b69f
@@ -150,10 +150,11 @@ def set_gpu_thread_mode_and_count(flags_obj):
   # Limit data preprocessing threadpool to CPU cores minus number of total GPU
   # private threads and memory copy threads.
   total_gpu_thread_count = per_gpu_thread_count * flags_obj.num_gpus
-  num_mem_copy_threads = flags_obj.num_gpus
+  num_runtime_threads = flags_obj.num_gpus
   if not flags_obj.datasets_num_private_threads:
-    flags_obj.datasets_num_private_threads = (cpu_count - total_gpu_thread_count
-                                              - num_mem_copy_threads)
+    flags_obj.datasets_num_private_threads = min(
+        cpu_count - total_gpu_thread_count - num_runtime_threads,
+        flags_obj.num_gpus * 8)
     tf.compat.v1.logging.info('Set datasets_num_private_threads to %s',
                               flags_obj.datasets_num_private_threads)
 
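Note on the thread budgeting above: the change keeps the old "CPU cores minus GPU helper threads" budget but additionally caps the tf.data private threadpool at 8 threads per GPU. A rough sketch of the arithmetic with made-up machine numbers (96 logical cores, 8 GPUs, 2 private threads per GPU; none of these values come from the diff):

import __future__  # placeholder-free, plain-Python illustration

cpu_count = 96              # hypothetical logical core count
num_gpus = 8                # hypothetical GPU count
per_gpu_thread_count = 2    # hypothetical gpu_private thread count per GPU

total_gpu_thread_count = per_gpu_thread_count * num_gpus  # 16
num_runtime_threads = num_gpus                            # 8

# Old heuristic: give tf.data everything that is left over.
old_threads = cpu_count - total_gpu_thread_count - num_runtime_threads  # 72

# New heuristic: same budget, but capped at 8 threads per GPU to reduce
# CPU contention between input preprocessing and the TF runtime.
new_threads = min(cpu_count - total_gpu_thread_count - num_runtime_threads,
                  num_gpus * 8)                                          # 64

print(old_threads, new_threads)  # 72 64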
@@ -283,6 +284,13 @@ def define_keras_flags():
       'triggers the profiler to process 3 steps, starting from the 2nd step. '
       'Note that profiler has a non-trivial performance overhead, and the '
       'output file can be gigantic if profiling many steps.')
+  flags.DEFINE_boolean(
+      name='data_prefetch_with_slack', default=False,
+      help='Add a small delay in tf.data prefetch to prioritize memory copy of '
+      'other tensors over the data minibatch for the (T+1)th step. It should '
+      'help improve performance using EagerIterator and function. The codepath '
+      'when enabling this feature is experimental and will be removed once the '
+      'corresponding performance features are fully supported in TensorFlow.')
 
 
 def get_synth_input_fn(height, width, num_channels, num_classes,
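The help text above describes the intent of the new flag; in this change the delay is actually injected inside MultiDeviceIterator by the monkey patch further down. As a self-contained illustration of the same "slack before prefetch" idea using only public tf.data APIs (add_prefetch_slack, its parameters, and the toy dataset are hypothetical, not part of this change):

import time

import tensorflow as tf


def add_prefetch_slack(dataset, slack_ms=11, buffer_size=1):
  """Sleeps ~slack_ms before each element is handed to prefetch."""

  def _delay(*element):
    # tf.py_function runs the Python-level sleep inside the tf.data pipeline,
    # giving the current step's host-to-device copies a head start over the
    # next minibatch's input processing.
    sleep_done = tf.py_function(
        lambda: time.sleep(slack_ms / 1000.0) or 0.0, [], tf.float32)
    with tf.control_dependencies([sleep_done]):
      delayed = tuple(tf.identity(t) for t in element)
    return delayed if len(delayed) > 1 else delayed[0]

  return dataset.map(_delay).prefetch(buffer_size)


# Hypothetical usage with a toy pipeline; a real input_fn would batch records.
ds = tf.data.Dataset.range(8).batch(2)
ds = add_prefetch_slack(ds, slack_ms=11)
for batch in ds:
  pass  # the training step would run here, overlapping with the slowed producer

The production patch instead injects the internal sleep transformation (tensorflow.python.data.experimental.ops.sleep) into each per-device dataset, which keeps Python callbacks off the hot path.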
@@ -341,6 +349,12 @@ def is_v2_0():
   return tf.__version__.startswith('2')
 
 
+def data_prefetch_with_slack():
+  """Use unstable code for perf tuning purposes."""
+  if not FLAGS.use_synthetic_data:
+    _monkey_patch_org_create_device_dataset()
+
+
 def _monkey_patch_org_assert_broadcastable():
   """Monkey-patch `assert_broadcast` op to avoid OOM when enabling XLA."""
   def no_op_assert_broadcastable(weights, values):
@@ -362,3 +376,29 @@ def _undo_monkey_patch_org_assert_broadcastable():
   if hasattr(weights_broadcast_ops, 'org_assert_broadcastable'):
     weights_broadcast_ops.assert_broadcastable = (
         weights_broadcast_ops.org_assert_broadcastable)
+
+
+# TODO(haoyuzhang): remove this monkey patch when the "prefetch with slack"
+# feature is available in tf.data.
+def _monkey_patch_org_create_device_dataset():
+  """Monkey-patch `_create_device_dataset` method with delayed prefetch."""
+  import ast  # pylint: disable=g-import-not-at-top
+  import inspect  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.data.ops import multi_device_iterator_ops  # pylint: disable=g-import-not-at-top
+
+  tf.compat.v1.logging.info(
+      'Using monkey-patched version of MultiDeviceIterator. It should be '
+      'removed when the prefetch with slack feature is implemented in tf.data.')
+  cls_multi_device_iterator = ast.parse(
+      inspect.getsource(multi_device_iterator_ops.MultiDeviceIterator))
+  org_create_device_dataset_code = inspect.getsource(
+      multi_device_iterator_ops.MultiDeviceIterator._create_device_dataset)  # pylint: disable=protected-access
+  code_lines = org_create_device_dataset_code.split('\n')
+  # Insert in reverse order to avoid line number shift by previous insertions
+  code_lines.insert(5, '    ds = ds.apply(sleep_ops.sleep(11000))')  # 11ms
+  code_lines.insert(2, '    from tensorflow.python.data.experimental.ops import sleep as sleep_ops')  # pylint: disable=line-too-long
+  patched_code = '\n'.join(line[2:] for line in code_lines)
+
+  cls_multi_device_iterator.body[0].body[2] = ast.parse(patched_code).body[0]
+  exec(compile(cls_multi_device_iterator, '<string>', 'exec'),  # pylint: disable=exec-used
+       multi_device_iterator_ops.__dict__)
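The patch above rewrites the source of MultiDeviceIterator._create_device_dataset rather than wrapping it: it re-parses the class with ast, splices two extra lines into the method body (an import of the internal sleep transformation and the 11 ms ds.apply(sleep_ops.sleep(11000)) call), and re-executes the class definition in the module's namespace. A miniature of the same inspect/ast/exec recipe on a toy class (Greeter and everything below are hypothetical, for illustration only):

import ast
import inspect


class Greeter(object):
  """Toy stand-in for the class whose method gets patched."""

  def greet(self, name):
    message = 'Hello, ' + name
    return message


def _monkey_patch_greet():
  # Parse the whole class so one method node can be swapped out.
  cls_tree = ast.parse(inspect.getsource(Greeter))

  # Edit the method as text: split into lines, insert a statement, then strip
  # the two-space class indentation so the snippet parses on its own.
  method_lines = inspect.getsource(Greeter.greet).split('\n')
  method_lines.insert(2, "    message = message.upper()")
  patched_src = '\n'.join(line[2:] for line in method_lines)

  # Class body is [docstring, greet]; replace greet and re-exec the class.
  cls_tree.body[0].body[1] = ast.parse(patched_src).body[0]
  exec(compile(cls_tree, '<string>', 'exec'), globals())  # pylint: disable=exec-used


_monkey_patch_greet()
print(Greeter().greet('tf.data'))  # HELLO, TF.DATA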
@@ -251,6 +251,21 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.batch_size = 256
     self._run_and_report_benchmark()
 
+  def benchmark_xla_1_gpu_fp16_tweaked(self):
+    """Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
+    self._setup()
+
+    FLAGS.num_gpus = 1
+    FLAGS.enable_eager = True
+    FLAGS.enable_xla = True
+    FLAGS.distribution_strategy = 'default'
+    FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
+    FLAGS.dtype = 'fp16'
+    FLAGS.batch_size = 256
+    FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.data_prefetch_with_slack = True
+    self._run_and_report_benchmark()
+
   def benchmark_xla_1_gpu_fp16_dynamic(self):
     """Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
     self._setup()
@@ -313,6 +328,23 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.batch_size = 256
     self._run_and_report_benchmark()
 
+  def benchmark_graph_xla_1_gpu_fp16_tweaked(self):
+    """Test Keras model in legacy graph mode with 1 GPU, fp16, XLA, and manual
+    config tuning.
+    """
+    self._setup()
+
+    FLAGS.num_gpus = 1
+    FLAGS.enable_eager = False
+    FLAGS.enable_xla = True
+    FLAGS.distribution_strategy = 'default'
+    FLAGS.model_dir = self._get_model_dir(
+        'benchmark_graph_xla_1_gpu_fp16_tweaked')
+    FLAGS.dtype = 'fp16'
+    FLAGS.batch_size = 256
+    FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    self._run_and_report_benchmark()
+
   def benchmark_8_gpu(self):
     """Test Keras model with 8 GPUs."""
     self._setup()
@@ -334,6 +366,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
     FLAGS.batch_size = 128 * 8  # 8 GPUs
     FLAGS.datasets_num_private_threads = 14
+    FLAGS.data_prefetch_with_slack = True
     self._run_and_report_benchmark()
 
   def benchmark_xla_8_gpu(self):
@@ -371,6 +404,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.data_prefetch_with_slack = True
     self._run_and_report_benchmark()
 
   def benchmark_8_gpu_fp16_dynamic_tweaked(self):
@@ -386,6 +420,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.loss_scale = 'dynamic'
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.data_prefetch_with_slack = True
     self._run_and_report_benchmark()
 
   def benchmark_xla_8_gpu_fp16(self):
@@ -412,7 +447,8 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.distribution_strategy = 'default'
     FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    # FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.data_prefetch_with_slack = True
     self._run_and_report_benchmark()
 
   def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
@@ -429,6 +465,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.loss_scale = 'dynamic'
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.data_prefetch_with_slack = True
     self._run_and_report_benchmark()
 
   def benchmark_xla_8_gpu_fp16_tensorboard_tweaked(self):
@@ -444,6 +481,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
         'benchmark_xla_8_gpu_fp16_tensorboard_tweaked')
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.data_prefetch_with_slack = True
     FLAGS.enable_tensorboard = True
     self._run_and_report_benchmark()
 
@@ -636,6 +674,7 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
     FLAGS.batch_size = 256 * 8
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    FLAGS.data_prefetch_with_slack = True
     self._run_and_report_benchmark()
 
   def benchmark_graph_8_gpu(self):
@@ -107,6 +107,8 @@ def run(flags_obj):
   # Execute flag override logic for better model performance
   if flags_obj.tf_gpu_thread_mode:
     keras_common.set_gpu_thread_mode_and_count(flags_obj)
+  if flags_obj.data_prefetch_with_slack:
+    keras_common.data_prefetch_with_slack()
 
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'float16':
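Because the patch works by re-executing MultiDeviceIterator from an in-memory string, one cheap way to confirm at runtime that run() actually applied it is to look at where the method's code object claims to live. This check is a hypothetical diagnostic, not part of the change:

import inspect

from tensorflow.python.data.ops import multi_device_iterator_ops


def slack_patch_is_active():
  """True once keras_common.data_prefetch_with_slack() has applied the patch.

  The patch is skipped when synthetic data is in use, in which case the method
  still points at the original multi_device_iterator_ops.py source file.
  """
  method = multi_device_iterator_ops.MultiDeviceIterator._create_device_dataset  # pylint: disable=protected-access
  # The monkey patch re-execs the class via compile(..., '<string>', 'exec'),
  # so the patched method's code object no longer points at a real .py file.
  return inspect.getfile(method) == '<string>'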