Commit 74556d99 authored by Tayo Oguntebi's avatar Tayo Oguntebi Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307683464
parent 338a0fc2
...@@ -42,7 +42,7 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -42,7 +42,7 @@ class TransformerBenchmark(PerfZeroBenchmark):
""" """
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None, def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None): flag_methods=None, tpu=None):
root_data_dir = root_data_dir if root_data_dir else '' root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir, self.train_data_dir = os.path.join(root_data_dir,
...@@ -68,7 +68,8 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -68,7 +68,8 @@ class TransformerBenchmark(PerfZeroBenchmark):
super(TransformerBenchmark, self).__init__( super(TransformerBenchmark, self).__init__(
output_dir=output_dir, output_dir=output_dir,
default_flags=default_flags, default_flags=default_flags,
flag_methods=flag_methods) flag_methods=flag_methods,
tpu=tpu)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -428,7 +429,7 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -428,7 +429,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
"""Benchmarks for Transformer (Base and Big) using Keras.""" """Benchmarks for Transformer (Base and Big) using Keras."""
def __init__(self, output_dir=None, default_flags=None, def __init__(self, output_dir=None, default_flags=None,
root_data_dir=None, batch_per_gpu=4096): root_data_dir=None, batch_per_gpu=4096, tpu=None):
"""Initialize. """Initialize.
Args: Args:
...@@ -436,6 +437,7 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -436,6 +437,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
default_flags: default flags to use for all tests. default_flags: default flags to use for all tests.
root_data_dir: root directory for data, e.g. training. root_data_dir: root directory for data, e.g. training.
batch_per_gpu: batch size to use per gpu. batch_per_gpu: batch size to use per gpu.
tpu: Target TPU to use.
""" """
flag_methods = [misc.define_transformer_flags] flag_methods = [misc.define_transformer_flags]
self.batch_per_gpu = batch_per_gpu self.batch_per_gpu = batch_per_gpu
...@@ -444,7 +446,8 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -444,7 +446,8 @@ class TransformerKerasBenchmark(TransformerBenchmark):
output_dir=output_dir, output_dir=output_dir,
default_flags=default_flags, default_flags=default_flags,
root_data_dir=root_data_dir, root_data_dir=root_data_dir,
flag_methods=flag_methods) flag_methods=flag_methods,
tpu=tpu)
def benchmark_1_gpu_no_dist_strat(self): def benchmark_1_gpu_no_dist_strat(self):
"""Benchmark 1 gpu without distribution strategy.""" """Benchmark 1 gpu without distribution strategy."""
...@@ -666,7 +669,8 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -666,7 +669,8 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests.""" """Transformer based version real data benchmark tests."""
def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR, **kwargs): def __init__(self, output_dir=TMP_DIR, root_data_dir=TMP_DIR,
tpu=None, **kwargs):
def_flags = {} def_flags = {}
def_flags['param_set'] = 'big' def_flags['param_set'] = 'big'
def_flags['train_steps'] = 50 def_flags['train_steps'] = 50
...@@ -674,7 +678,27 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -674,7 +678,27 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
super(TransformerBigKerasBenchmarkReal, self).__init__( super(TransformerBigKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, output_dir=output_dir, default_flags=def_flags,
root_data_dir=root_data_dir, batch_per_gpu=3072) root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu)
def benchmark_2x2_tpu(self):
  """Runs the transformer_big benchmark on a 2x2 TPU slice.

  Port of the former snaggletooth transformer_big model on 2x2.
  Configures flags for a short (300-step) TPU run with a static batch,
  a custom training loop, and padded decoding, then reports results.
  """
  self._setup()
  # Flag overrides for this benchmark configuration. Values mirror the
  # original snaggletooth 2x2 setup; checkpointing is disabled to keep
  # the benchmark run lightweight.
  overrides = {
      'model_dir': self._get_model_dir('benchmark_2x2_tpu'),
      'train_steps': 300,
      'distribution_strategy': 'tpu',
      'static_batch': True,
      'use_ctl': True,
      'batch_size': 6144,
      'max_length': 64,
      'decode_batch_size': 32,
      'decode_max_length': 97,
      'padded_decode': True,
      'enable_checkpointing': False,
  }
  for flag_name, flag_value in overrides.items():
    setattr(FLAGS, flag_name, flag_value)
  self._run_and_report_benchmark(
      total_batch_size=FLAGS.batch_size,
      log_steps=FLAGS.log_steps)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment