Commit fbdbe12b authored by saberkun's avatar saberkun Committed by Toby Boyd
Browse files

Merged commit includes the following changes: (#6952)

251325964  by hongkuny<hongkuny@google.com>:

    Improve flags

--
250942274  by tobyboyd<tobyboyd@google.com>:

    Internal change

PiperOrigin-RevId: 251325964
parent e59ad48f
......@@ -168,7 +168,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
custom_callbacks=callbacks)
class BertClassifyBenchmark(BertBenchmarkBase):
class BertClassifyBenchmarkReal(BertBenchmarkBase):
"""Short benchmark performance tests for BERT model.
Tests BERT classification performance in different GPU configurations.
......@@ -187,7 +187,7 @@ class BertClassifyBenchmark(BertBenchmarkBase):
self.num_steps_per_epoch = 110
self.num_epochs = 1
super(BertClassifyBenchmark, self).__init__(output_dir=output_dir)
super(BertClassifyBenchmarkReal, self).__init__(output_dir=output_dir)
def _run_and_report_benchmark(self,
training_summary_path,
......@@ -205,7 +205,7 @@ class BertClassifyBenchmark(BertBenchmarkBase):
# Since we do not load from any pretrained checkpoints, we ignore all
# accuracy metrics.
summary.pop('eval_metrics', None)
super(BertClassifyBenchmark, self)._report_benchmark(
super(BertClassifyBenchmarkReal, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
......@@ -227,7 +227,7 @@ class BertClassifyBenchmark(BertBenchmarkBase):
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_2_gpu_mprc(self):
def benchmark_2_gpu_mrpc(self):
"""Test BERT model performance with 2 GPUs."""
self._setup()
......
......@@ -29,18 +29,19 @@ from official.bert import squad_lib
FLAGS = flags.FLAGS
# BERT classification specific flags.
flags.DEFINE_enum(
"fine_tuning_task_type", "classification", ["classification", "squad"],
"The name of the BERT fine tuning task for which data "
"will be generated..")
# BERT classification specific flags.
flags.DEFINE_string(
"input_data_dir", None,
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_string("classification_task_name", None,
flags.DEFINE_enum("classification_task_name", "mnli",
["cola", "mnli", "mrpc", "xnli"],
"The name of the task to train BERT classifier.")
# BERT Squad task specific flags.
......@@ -58,6 +59,10 @@ flags.DEFINE_integer(
"The maximum number of tokens for the question. Questions longer than "
"this will be truncated to this length.")
flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
......@@ -88,10 +93,6 @@ flags.DEFINE_integer(
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment