"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "04ec6ba2ac7a4e4beee8be9dc15bc1922544ca82"
Commit 80f8c46e authored by Zongwei Zhou's avatar Zongwei Zhou Committed by A. Unique TensorFlower
Browse files

Run BERT benchmarks longer to accommodate a larger steps_per_loop value

Run more steps, or run with a smaller steps_per_loop value, so that each benchmark has multiple batches (loops) for calculating exp_per_seconds

PiperOrigin-RevId: 296129781
parent 1f61912a
......@@ -81,7 +81,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus)
steps_per_loop = 100
steps_per_loop = 50
max_seq_length = input_meta_data['max_seq_length']
train_input_fn = run_classifier.get_dataset_fn(
......@@ -132,7 +132,7 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
# Since we only care about performance metrics, we limit
# the number of training steps and epochs to prevent unnecessarily
# long tests.
self.num_steps_per_epoch = 110
self.num_steps_per_epoch = 100
self.num_epochs = 1
@benchmark_wrappers.enable_runtime_flags
......
......@@ -42,6 +42,7 @@ SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
SQUAD_MEDIUM_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_medium_meta_data'
SQUAD_LONG_INPUT_META_DATA_PATH = '/placer/prod/home/tensorflow-performance-data/datasets/bert/squad/squad_long_meta_data'
SQUAD_FULL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_full_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
# pylint: enable=line-too-long
......@@ -100,8 +101,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
num_gpus=self.num_gpus,
datasets_num_private_threads=FLAGS.datasets_num_private_threads)
@flagsaver.flagsaver
def _train_squad(self, use_ds=True, run_eagerly=False):
"""Runs BERT SQuAD training."""
......@@ -152,7 +151,6 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100
......@@ -162,6 +160,10 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
use_ds=True,
run_eagerly=False):
"""Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4:
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
else:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
wall_time_sec = time.time() - start_time_sec
......@@ -578,7 +580,7 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1
FLAGS.steps_per_loop = 100
......@@ -588,6 +590,10 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
use_ds=True,
run_eagerly=False):
"""Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4 * 8:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
else:
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
wall_time_sec = time.time() - start_time_sec
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment