"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "48b6992b71acbda35ec63ba811ed3b8ed9f1f90c"
Commit 2be9ba5b authored by guptapriya's avatar guptapriya
Browse files

Turn dist strat off for 1 GPU benchmarks

parent 1d16f473
...@@ -133,6 +133,7 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -133,6 +133,7 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
""" """
self._setup() self._setup()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.distribution_strategy = off
FLAGS.data_dir = self.train_data_dir FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check. # Sets values directly to avoid validation check.
...@@ -158,6 +159,7 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -158,6 +159,7 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
""" """
self._setup() self._setup()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.distribution_strategy = off
FLAGS.data_dir = self.train_data_dir FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check. # Sets values directly to avoid validation check.
...@@ -315,6 +317,7 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -315,6 +317,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
"""Benchmark 1 gpu.""" """Benchmark 1 gpu."""
self._setup() self._setup()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.distribution_strategy = off
FLAGS.batch_size = self.batch_per_gpu FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu') FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size, self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
...@@ -324,6 +327,7 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -324,6 +327,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
"""Benchmark 1 gpu.""" """Benchmark 1 gpu."""
self._setup() self._setup()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.distribution_strategy = off
FLAGS.batch_size = self.batch_per_gpu FLAGS.batch_size = self.batch_per_gpu
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch') FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_static_batch')
# TODO(guptapriya): Add max_length # TODO(guptapriya): Add max_length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment