Commit 3bb5dd6c authored by guptapriya's avatar guptapriya
Browse files

fix num_gpus in benchmark

parent b5a69819
...@@ -347,7 +347,7 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -347,7 +347,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
def benchmark_8_gpu_static_batch(self): def benchmark_8_gpu_static_batch(self):
"""Benchmark 8 gpu.""" """Benchmark 8 gpu."""
self._setup() self._setup()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 8
FLAGS.batch_size = self.batch_per_gpu * 8 FLAGS.batch_size = self.batch_per_gpu * 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_static_batch')
# TODO(guptapriya): Add max_length # TODO(guptapriya): Add max_length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment