Commit 3af667fd authored by Haoyu Zhang's avatar Haoyu Zhang Committed by A. Unique TensorFlower
Browse files

Remove the data_delay_prefetch flag and monkey patch.

PiperOrigin-RevId: 280465279
parent 63d754ec
...@@ -338,13 +338,6 @@ def define_keras_flags(dynamic_loss_scale=True): ...@@ -338,13 +338,6 @@ def define_keras_flags(dynamic_loss_scale=True):
    'triggers the profiler to process 3 steps, starting from the 2nd step. '
    'Note that profiler has a non-trivial performance overhead, and the '
    'output file can be gigantic if profiling many steps.')
flags.DEFINE_boolean(
name='data_delay_prefetch', default=False,
help='Add a small delay in tf.data prefetch to prioritize memory copy of '
'other tensors over the data minibatch for the (T+1)th step. It should '
'help improve performance using EagerIterator and function. The codepath '
'when enabling this feature is experimental and will be removed once the '
'corresponding performance features are fully supported in TensorFlow.')
flags.DEFINE_boolean(
    name='batchnorm_spatial_persistent', default=True,
    help='Enable the spacial persistent mode for CuDNN batch norm kernel.')
...@@ -413,12 +406,6 @@ def get_synth_input_fn(height, width, num_channels, num_classes, ...@@ -413,12 +406,6 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
return input_fn return input_fn
def data_delay_prefetch():
  """Use unstable code for perf tuning purposes.

  Installs the experimental delayed-prefetch monkey patch for real-data
  runs; synthetic data does not exercise the patched prefetch path, so
  nothing is done in that case.
  """
  if FLAGS.use_synthetic_data:
    return
  _monkey_patch_org_create_device_dataset()
def set_cudnn_batchnorm_mode():
  """Set CuDNN batchnorm mode for better performance.
...@@ -429,29 +416,3 @@ def set_cudnn_batchnorm_mode():
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
  else:
    os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)
# TODO(haoyuzhang): remove this monkey patch when the "prefetch with slack"
# feature is available in tf.data.
def _monkey_patch_org_create_device_dataset():
  """Monkey-patch `_create_device_dataset` method with delayed prefetch.

  Rewrites the source of
  `MultiDeviceIterator._create_device_dataset` (obtained via `inspect`) to
  insert an 11ms `sleep` into the per-device dataset pipeline, then
  re-executes the patched class definition inside the
  `multi_device_iterator_ops` module namespace, replacing the original
  class for all subsequent instantiations.
  """
  # Local imports: this experimental codepath is only taken when the
  # feature flag is on, so keep the cost off module import.
  import ast  # pylint: disable=g-import-not-at-top
  import inspect  # pylint: disable=g-import-not-at-top
  from tensorflow.python.data.ops import multi_device_iterator_ops  # pylint: disable=g-import-not-at-top
  tf.compat.v1.logging.info(
      'Using monkey-patched version of MultiDeviceIterator. It should be '
      'removed when the prefetch with slack feature is implemented in tf.data.')
  # Parse the whole class source into an AST so a single method body can be
  # swapped out below.
  cls_multi_device_iterator = ast.parse(
      inspect.getsource(multi_device_iterator_ops.MultiDeviceIterator))
  org_create_device_dataset_code = inspect.getsource(
      multi_device_iterator_ops.MultiDeviceIterator._create_device_dataset)  # pylint: disable=protected-access
  code_lines = org_create_device_dataset_code.split('\n')
  # Insert in reverse order to avoid line number shift by previous insertions
  code_lines.insert(5, ' ds = ds.apply(sleep_ops.sleep(11000))')  # 11ms
  code_lines.insert(2, ' from tensorflow.python.data.experimental.ops import sleep as sleep_ops')  # pylint: disable=line-too-long
  # Strip two leading characters per line so the method source parses as a
  # standalone function definition.
  # NOTE(review): indentation inside the two inserted strings above appears
  # collapsed by the diff rendering -- confirm against the original source.
  patched_code = '\n'.join(line[2:] for line in code_lines)
  # NOTE(review): assumes `_create_device_dataset` is the statement at index 2
  # of the class body in the targeted TF version -- verify before reuse.
  cls_multi_device_iterator.body[0].body[2] = ast.parse(patched_code).body[0]
  # Re-execute the modified class definition in the module's own namespace,
  # rebinding `MultiDeviceIterator` to the patched class.
  exec(compile(cls_multi_device_iterator, '<string>', 'exec'),  # pylint: disable=exec-used
       multi_device_iterator_ops.__dict__)
...@@ -55,8 +55,6 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
  common.set_gpu_thread_mode_and_count(flags_obj)
if flags_obj.data_delay_prefetch:
  common.data_delay_prefetch()
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment