Commit 80f8c46e authored by Zongwei Zhou, committed by A. Unique TensorFlower
Browse files

Run BERT benchmarks longer to accommodate larger steps_per_loop value

Run more steps, or run with a smaller steps_per_loop value, so that each benchmark has multiple batches (loops) for calculating exp_per_seconds.

PiperOrigin-RevId: 296129781
parent 1f61912a
...@@ -81,7 +81,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -81,7 +81,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
distribution_strategy='mirrored' if use_ds else 'off', distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus) num_gpus=self.num_gpus)
steps_per_loop = 100 steps_per_loop = 50
max_seq_length = input_meta_data['max_seq_length'] max_seq_length = input_meta_data['max_seq_length']
train_input_fn = run_classifier.get_dataset_fn( train_input_fn = run_classifier.get_dataset_fn(
...@@ -132,7 +132,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase): ...@@ -132,7 +132,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
# Since we only care about performance metrics, we limit # Since we only care about performance metrics, we limit
# the number of training steps and epochs to prevent unnecessarily # the number of training steps and epochs to prevent unnecessarily
# long tests. # long tests.
self.num_steps_per_epoch = 110 self.num_steps_per_epoch = 100
self.num_epochs = 1 self.num_epochs = 1
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
......
...@@ -42,6 +42,7 @@ SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record' ...@@ -42,6 +42,7 @@ SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json' SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt' SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data' SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
SQUAD_LONG_INPUT_META_DATA_PATH = '/placer/prod/home/tensorflow-performance-data/datasets/bert/squad/squad_long_meta_data'
SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data' SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json' MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
# pylint: enable=line-too-long # pylint: enable=line-too-long
...@@ -100,8 +101,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -100,8 +101,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
num_gpus=self.num_gpus, num_gpus=self.num_gpus,
datasets_num_private_threads=FLAGS.datasets_num_private_threads) datasets_num_private_threads=FLAGS.datasets_num_private_threads)
@flagsaver.flagsaver @flagsaver.flagsaver
def _train_squad(self, use_ds=True, run_eagerly=False): def _train_squad(self, use_ds=True, run_eagerly=False):
"""Runs BERT SQuAD training.""" """Runs BERT SQuAD training."""
...@@ -152,7 +151,6 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -152,7 +151,6 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1 FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100 FLAGS.steps_per_loop = 100
...@@ -162,6 +160,10 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -162,6 +160,10 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
use_ds=True, use_ds=True,
run_eagerly=False): run_eagerly=False):
"""Runs the benchmark and reports various metrics.""" """Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4:
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
else:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
start_time_sec = time.time() start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly) self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
...@@ -578,7 +580,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase): ...@@ -578,7 +580,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1 FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100 FLAGS.steps_per_loop = 100
...@@ -588,6 +590,10 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase): ...@@ -588,6 +590,10 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
use_ds=True, use_ds=True,
run_eagerly=False): run_eagerly=False):
"""Runs the benchmark and reports various metrics.""" """Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4 * 8:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
else:
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
start_time_sec = time.time() start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly) self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment