Commit 956f295f authored by Ayush Dubey's avatar Ayush Dubey Committed by A. Unique TensorFlower
Browse files

Set `datasets_num_private_threads` flag for multi-worker tuned benchmarks.

PiperOrigin-RevId: 268836696
parent cbf29854
...@@ -990,6 +990,7 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -990,6 +990,7 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'multi_worker_mirrored' FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.use_tensor_lr = True FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir( FLAGS.model_dir = self._get_model_dir(
'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format( 'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
'eager' if eager else 'graph', num_workers, all_reduce_alg)) 'eager' if eager else 'graph', num_workers, all_reduce_alg))
...@@ -1068,6 +1069,7 @@ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase): ...@@ -1068,6 +1069,7 @@ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
FLAGS.distribution_strategy = 'multi_worker_mirrored' FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.use_tensor_lr = True FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir( FLAGS.model_dir = self._get_model_dir(
'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format( 'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
'eager' if eager else 'graph', num_workers, all_reduce_alg)) 'eager' if eager else 'graph', num_workers, all_reduce_alg))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment