Commit 956f295f authored by Ayush Dubey's avatar Ayush Dubey Committed by A. Unique TensorFlower
Browse files

Set `datasets_num_private_threads` flag for multi-worker tuned benchmarks.

PiperOrigin-RevId: 268836696
parent cbf29854
......@@ -990,6 +990,7 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir(
'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
'eager' if eager else 'graph', num_workers, all_reduce_alg))
......@@ -1068,6 +1069,7 @@ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir(
'benchmark_{}_8_gpu_{}_worker_fp16_{}_tweaked'.format(
'eager' if eager else 'graph', num_workers, all_reduce_alg))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment