Commit 03e4fa2d authored by Allen Wang, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 321652782
parent 709a6617
...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc ...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
# Root of the pre-generated WMT'32k data when running on TPU (GCS bucket).
TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
# Root of the data when running on GPU hosts. TMPDIR may be unset, and
# os.getenv would then return None, which makes the later
# os.path.join(root_data_dir, ...) calls raise TypeError — fall back to ''.
GPU_DATA_DIR = os.getenv('TMPDIR') or ''
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official' TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014' EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -40,37 +42,50 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -40,37 +42,50 @@ class TransformerBenchmark(PerfZeroBenchmark):
Code under test for the Transformer Keras models report the same data and Code under test for the Transformer Keras models report the same data and
require the same FLAG setup. require the same FLAG setup.
""" """
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
             flag_methods=None, tpu=None):
  """Initializes the benchmark and seeds default flags with data paths.

  Args:
    output_dir: Directory where benchmark output is written.
    default_flags: Optional dict of flag defaults; data_dir and vocab_file
      entries are filled in from the resolved data locations.
    root_data_dir: Unused here; data roots come from _set_data_files().
    flag_methods: Flag-definition callables forwarded to the base class.
    tpu: TPU address forwarded to the base class.
  """
  self._set_data_files()
  flags_to_use = {} if default_flags is None else default_flags
  flags_to_use['data_dir'] = self.train_data_dir
  flags_to_use['vocab_file'] = self.vocab_file
  super(TransformerBenchmark, self).__init__(
      output_dir=output_dir,
      default_flags=flags_to_use,
      flag_methods=flag_methods,
      tpu=tpu)
def _set_data_files(self, tpu_run=False):
  """Sets train_data_dir, vocab_file, bleu_source and bleu_ref.

  Args:
    tpu_run: If True, paths are rooted at the TPU GCS bucket
      (TPU_DATA_DIR); otherwise at the GPU host data dir (GPU_DATA_DIR).
  """
  if tpu_run:
    root_data_dir = TPU_DATA_DIR
  else:
    # GPU_DATA_DIR comes from os.getenv('TMPDIR') and may be None when the
    # env var is unset; guard with '' so os.path.join doesn't raise.
    root_data_dir = GPU_DATA_DIR or ''
  self.train_data_dir = os.path.join(root_data_dir,
                                     TRANSFORMER_EN2DE_DATA_DIR_NAME)
  self.vocab_file = os.path.join(root_data_dir,
                                 TRANSFORMER_EN2DE_DATA_DIR_NAME,
                                 'vocab.ende.32768')
  self.bleu_source = os.path.join(root_data_dir,
                                  EN2DE_2014_BLEU_DATA_DIR_NAME,
                                  'newstest2014.en')
  self.bleu_ref = os.path.join(root_data_dir,
                               EN2DE_2014_BLEU_DATA_DIR_NAME,
                               'newstest2014.de')
def _set_data_file_flags(self):
  """Sets the FLAGS for the data files."""
  FLAGS.data_dir = self.train_data_dir
  FLAGS.vocab_file = self.vocab_file
  # Assign through the flag objects to bypass the flag validation check.
  FLAGS['bleu_source'].value = self.bleu_source
  FLAGS['bleu_ref'].value = self.bleu_ref
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -164,12 +179,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -164,12 +179,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 2048 FLAGS.batch_size = 2048
FLAGS.train_steps = 1000 FLAGS.train_steps = 1000
...@@ -189,12 +200,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -189,12 +200,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -215,12 +222,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -215,12 +222,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -237,12 +240,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -237,12 +240,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -284,12 +283,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -284,12 +283,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Iterations are not epochs, an iteration is a number of steps between evals. Iterations are not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -306,12 +301,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -306,12 +301,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -337,13 +328,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -337,13 +328,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
not epochs, an iteration is a number of steps between evals. not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -360,14 +347,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -360,14 +347,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -384,13 +367,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -384,13 +367,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -409,14 +388,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -409,14 +388,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -687,22 +662,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -687,22 +662,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
root_data_dir=root_data_dir, batch_per_gpu=3072, root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu) tpu=tpu)
def _set_df_common(self):
  """Applies the flag settings shared by all TPU transformer_big runs."""
  self._set_data_files(tpu_run=True)
  FLAGS.data_dir = self.train_data_dir
  FLAGS.vocab_file = self.vocab_file
  FLAGS.train_steps = 300
  FLAGS.log_steps = 150
  FLAGS.steps_between_evals = 150
  FLAGS.distribution_strategy = 'tpu'
  FLAGS.static_batch = True
  FLAGS.use_ctl = True
  FLAGS.enable_checkpointing = False
  FLAGS.max_length = 64
  FLAGS.decode_batch_size = 32
  FLAGS.decode_max_length = 97
  FLAGS.padded_decode = True
def benchmark_2x2_tpu(self):
  """Port of former snaggletooth transformer_big model on 2x2."""
  self._setup()
  self._set_df_common()
  FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
  FLAGS.batch_size = 6144
  self._run_and_report_benchmark(
      total_batch_size=FLAGS.batch_size,
      log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
  """Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
  self._setup()
  self._set_df_common()
  FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
  FLAGS.batch_size = 6144
  tf.config.experimental.enable_mlir_bridge()
  # NOTE(review): the reporting-call arguments are truncated in this view;
  # completed to match the sibling benchmark methods — confirm against HEAD.
  self._run_and_report_benchmark(
      total_batch_size=FLAGS.batch_size,
      log_steps=FLAGS.log_steps)
...@@ -711,19 +705,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -711,19 +705,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu(self):
  """Port of former GCP transformer_big model on 4x4."""
  self._setup()
  self._set_df_common()
  FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
  FLAGS.batch_size = 24576
  self._run_and_report_benchmark(
      total_batch_size=FLAGS.batch_size,
      log_steps=FLAGS.log_steps)
...@@ -733,19 +717,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -733,19 +717,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu_mlir(self):
  """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
  self._setup()
  self._set_df_common()
  # New-side diff names the model dir after this benchmark (the old side
  # reused 'benchmark_4x4_tpu').
  FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
  FLAGS.batch_size = 24576
  tf.config.experimental.enable_mlir_bridge()
  self._run_and_report_benchmark(
      total_batch_size=FLAGS.batch_size,
      log_steps=FLAGS.log_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment