"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "7ef77d92110006e2cfb2b2e5d1eea2062994146f"
Commit 3097fd2a authored by Zongwei Zhou's avatar Zongwei Zhou Committed by A. Unique TensorFlower
Browse files

Enable setting GPU private thread in Resnet CTL model

PiperOrigin-RevId: 310236258
parent 34dd50a3
...@@ -113,8 +113,14 @@ def run(flags_obj): ...@@ -113,8 +113,14 @@ def run(flags_obj):
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj)) performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
# This only affects GPU. if tf.config.list_physical_devices('GPU'):
common.set_cudnn_batchnorm_mode() if flags_obj.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=flags_obj.per_gpu_thread_count,
gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
num_gpus=flags_obj.num_gpus,
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
common.set_cudnn_batchnorm_mode()
# TODO(anj-s): Set data_format without using Keras. # TODO(anj-s): Set data_format without using Keras.
data_format = flags_obj.data_format data_format = flags_obj.data_format
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment