Commit 3097fd2a authored by Zongwei Zhou's avatar Zongwei Zhou Committed by A. Unique TensorFlower
Browse files

Enable setting GPU private thread in Resnet CTL model

PiperOrigin-RevId: 310236258
parent 34dd50a3
......@@ -113,8 +113,14 @@ def run(flags_obj):
enable_xla=flags_obj.enable_xla)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
# This only affects GPU.
common.set_cudnn_batchnorm_mode()
if tf.config.list_physical_devices('GPU'):
if flags_obj.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=flags_obj.per_gpu_thread_count,
gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
num_gpus=flags_obj.num_gpus,
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
common.set_cudnn_batchnorm_mode()
# TODO(anj-s): Set data_format without using Keras.
data_format = flags_obj.data_format
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment