Unverified Commit 7af3bd91 authored by Hongjun Choi, committed by GitHub

Merged commit includes the following changes (#6898):

250347237  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix linting errors in BERT benchmark test.

--
250326131  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change

--
250315593  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change

--
250303528  by haoyuzhang<haoyuzhang@google.com>:

    Add method docstring to fix lint error.

--

PiperOrigin-RevId: 250347237
parent 8b52cd23
@@ -76,12 +76,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
         'value': stats['train_loss'],
     }]
-    if 'train_metrics' in stats:
-      metrics.append({
-          'name': 'train_accuracy',
-          'value': stats['train_metrics'],
-      })
-    if 'eval_metric' in stats:
+    if 'eval_metrics' in stats:
       metrics.append({
           'name': 'eval_accuracy',
           'value': stats['eval_metrics'],
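For context on the hunk above: `_report_benchmark` ultimately forwards to `tf.test.Benchmark.report_benchmark`, whose `metrics` argument is a list of dicts, each carrying a metric 'name' and a numeric 'value'. A minimal, self-contained sketch of that contract follows; the stats values and the `DemoBenchmark`/`benchmark_demo` names are invented for illustration:

    import tensorflow as tf

    # Invented sample of a parsed training summary like the one used above.
    stats = {'train_loss': 0.35, 'eval_metrics': 0.87}

    metrics = [{'name': 'train_loss', 'value': stats['train_loss']}]
    if 'eval_metrics' in stats:
      metrics.append({'name': 'eval_accuracy', 'value': stats['eval_metrics']})

    class DemoBenchmark(tf.test.Benchmark):

      def benchmark_demo(self):
        # report_benchmark accepts the metrics list of {'name', 'value'} dicts.
        self.report_benchmark(iters=-1, wall_time=12.3, metrics=metrics)

    DemoBenchmark().benchmark_demo()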
@@ -92,9 +87,58 @@ class BertBenchmarkBase(tf.test.Benchmark):
         wall_time=wall_time_sec,
         metrics=metrics)
 
+  @flagsaver.flagsaver
+  def _run_bert_classifier(self):
+    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+      input_meta_data = json.loads(reader.read().decode('utf-8'))
+    strategy = tf.distribute.MirroredStrategy()
+    run_classifier.run_bert(strategy, input_meta_data)
+
+
+class BertBenchmarkPerformanceTest(BertBenchmarkBase):
+  """Short benchmark performance tests for BERT model."""
+
+  def __init__(self, output_dir=None, **kwargs):
+    self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
+    self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
+    self.bert_config_file = MODEL_CONFIG_FILE_PATH
+    self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
+
+    super(BertBenchmarkPerformanceTest, self).__init__(output_dir=output_dir)
+
+  def _run_and_report_benchmark(self, training_summary_path):
+    """Starts BERT performance benchmark test."""
+    start_time_sec = time.time()
+    self._run_bert_classifier()
+    wall_time_sec = time.time() - start_time_sec
+
+    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
+      summary = json.loads(reader.read().decode('utf-8'))
+
+    # Since we do not load from any pretrained checkpoints, we ignore all
+    # accuracy metrics.
+    summary.pop('eval_metrics', None)
+    super(BertBenchmarkPerformanceTest, self)._report_benchmark(
+        stats=summary, wall_time_sec=wall_time_sec)
+
+  def benchmark_8_gpu(self):
+    """Test BERT model performance with 8 GPUs."""
+    self._setup()
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
+    FLAGS.train_data_path = self.train_data_path
+    FLAGS.eval_data_path = self.eval_data_path
+    FLAGS.input_meta_data_path = self.input_meta_data_path
+    FLAGS.bert_config_file = self.bert_config_file
+    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
+    self._run_and_report_benchmark(summary_path)
+
+
 class BertBenchmarkAccuracyTest(BertBenchmarkBase):
-  """Short benchmark tests for BERT model."""
+  """Short benchmark test for BERT model that tests accuracy metrics."""
 
   def __init__(self, output_dir=None, **kwargs):
     self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
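A note on the `@flagsaver.flagsaver` decorator hoisted into the base class above: it comes from `absl.testing` and snapshots every absl flag value before the decorated call, restoring the snapshot afterwards, so benchmark methods such as `benchmark_8_gpu` can overwrite FLAGS freely without leaking state into the next test. A minimal sketch of that behavior, using a hypothetical `demo_model_dir` flag:

    from absl import flags
    from absl.testing import flagsaver

    FLAGS = flags.FLAGS
    flags.DEFINE_string('demo_model_dir', '/tmp/default', 'Hypothetical flag.')

    @flagsaver.flagsaver
    def run_with_overrides():
      FLAGS.demo_model_dir = '/tmp/benchmark_8_gpu'  # override inside only
      return FLAGS.demo_model_dir

    FLAGS(['prog'])  # mark flags as parsed before accessing them
    assert run_with_overrides() == '/tmp/benchmark_8_gpu'
    assert FLAGS.demo_model_dir == '/tmp/default'  # restored on return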
@@ -105,15 +149,9 @@ class BertBenchmarkAccuracyTest(BertBenchmarkBase):
 
     super(BertBenchmarkAccuracyTest, self).__init__(output_dir=output_dir)
 
-  @flagsaver.flagsaver
-  def _run_bert_classifier(self):
-    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
-      input_meta_data = json.loads(reader.read().decode('utf-8'))
-    strategy = tf.distribute.MirroredStrategy()
-    run_classifier.run_bert(strategy, input_meta_data)
-
   def _run_and_report_benchmark(self, training_summary_path):
     """Starts BERT accuracy benchmark test."""
     start_time_sec = time.time()
     self._run_bert_classifier()
     wall_time_sec = time.time() - start_time_sec
@@ -125,6 +163,8 @@ class BertBenchmarkAccuracyTest(BertBenchmarkBase):
         stats=summary, wall_time_sec=wall_time_sec)
 
   def benchmark_8_gpu(self):
+    """Run BERT model accuracy test with 8 GPUs."""
     self._setup()
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
     FLAGS.train_data_path = self.train_data_path
@@ -166,8 +166,8 @@ def main(_):
     cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
         tpu=FLAGS.tpu)
     tf.config.experimental_connect_to_host(cluster_resolver.master())  # pylint: disable=line-too-long
-    tf.contrib.distribute.initialize_tpu_system(cluster_resolver)
-    strategy = tf.contrib.distribute.TPUStrategy(
+    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
+    strategy = tf.distribute.experimental.TPUStrategy(
         cluster_resolver, steps_per_run=FLAGS.steps_per_run)
   elif FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
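The hunk above, from a second file in the commit, replaces the removed `tf.contrib` TPU entry points with their post-contrib locations. Roughly, the TF 1.14/2.0-era setup sequence the new code relies on looks like the sketch below; the empty TPU address and `steps_per_run=100` are placeholder values, not taken from the commit:

    import tensorflow as tf

    tpu_address = ''  # placeholder; '' selects the locally configured TPU

    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_address)
    tf.config.experimental_connect_to_host(cluster_resolver.master())
    # initialize_tpu_system moved out of tf.contrib:
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    # TPUStrategy lived under tf.distribute.experimental at this point;
    # steps_per_run batches several training steps per host-to-TPU dispatch.
    strategy = tf.distribute.experimental.TPUStrategy(
        cluster_resolver, steps_per_run=100)

    with strategy.scope():
      # Variables created here are replicated across TPU cores.
      model = tf.keras.Sequential([tf.keras.layers.Dense(2)])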