Commit fbdbe12b authored by saberkun's avatar saberkun Committed by Toby Boyd
Browse files

Merged commit includes the following changes: (#6952)

251325964  by hongkuny<hongkuny@google.com>:

    Improve flags

--
250942274  by tobyboyd<tobyboyd@google.com>:

    Internal change

PiperOrigin-RevId: 251325964
parent e59ad48f
...@@ -168,7 +168,7 @@ class BertBenchmarkBase(tf.test.Benchmark): ...@@ -168,7 +168,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
custom_callbacks=callbacks) custom_callbacks=callbacks)
class BertClassifyBenchmark(BertBenchmarkBase): class BertClassifyBenchmarkReal(BertBenchmarkBase):
"""Short benchmark performance tests for BERT model. """Short benchmark performance tests for BERT model.
Tests BERT classification performance in different GPU configurations. Tests BERT classification performance in different GPU configurations.
...@@ -187,7 +187,7 @@ class BertClassifyBenchmark(BertBenchmarkBase): ...@@ -187,7 +187,7 @@ class BertClassifyBenchmark(BertBenchmarkBase):
self.num_steps_per_epoch = 110 self.num_steps_per_epoch = 110
self.num_epochs = 1 self.num_epochs = 1
super(BertClassifyBenchmark, self).__init__(output_dir=output_dir) super(BertClassifyBenchmarkReal, self).__init__(output_dir=output_dir)
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
training_summary_path, training_summary_path,
...@@ -205,7 +205,7 @@ class BertClassifyBenchmark(BertBenchmarkBase): ...@@ -205,7 +205,7 @@ class BertClassifyBenchmark(BertBenchmarkBase):
# Since we do not load from any pretrained checkpoints, we ignore all # Since we do not load from any pretrained checkpoints, we ignore all
# accuracy metrics. # accuracy metrics.
summary.pop('eval_metrics', None) summary.pop('eval_metrics', None)
super(BertClassifyBenchmark, self)._report_benchmark( super(BertClassifyBenchmarkReal, self)._report_benchmark(
stats=summary, stats=summary,
wall_time_sec=wall_time_sec, wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy, min_accuracy=min_accuracy,
...@@ -227,7 +227,7 @@ class BertClassifyBenchmark(BertBenchmarkBase): ...@@ -227,7 +227,7 @@ class BertClassifyBenchmark(BertBenchmarkBase):
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt') summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path) self._run_and_report_benchmark(summary_path)
def benchmark_2_gpu_mprc(self): def benchmark_2_gpu_mrpc(self):
"""Test BERT model performance with 2 GPUs.""" """Test BERT model performance with 2 GPUs."""
self._setup() self._setup()
......
...@@ -29,18 +29,19 @@ from official.bert import squad_lib ...@@ -29,18 +29,19 @@ from official.bert import squad_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
# BERT classification specific flags.
flags.DEFINE_enum( flags.DEFINE_enum(
"fine_tuning_task_type", "classification", ["classification", "squad"], "fine_tuning_task_type", "classification", ["classification", "squad"],
"The name of the BERT fine tuning task for which data " "The name of the BERT fine tuning task for which data "
    "will be generated.") "will be generated.")
# BERT classification specific flags.
flags.DEFINE_string( flags.DEFINE_string(
"input_data_dir", None, "input_data_dir", None,
"The input data dir. Should contain the .tsv files (or other data files) " "The input data dir. Should contain the .tsv files (or other data files) "
"for the task.") "for the task.")
flags.DEFINE_string("classification_task_name", None, flags.DEFINE_enum("classification_task_name", "mnli",
["cola", "mnli", "mrpc", "xnli"],
"The name of the task to train BERT classifier.") "The name of the task to train BERT classifier.")
# BERT Squad task specific flags. # BERT Squad task specific flags.
...@@ -58,6 +59,10 @@ flags.DEFINE_integer( ...@@ -58,6 +59,10 @@ flags.DEFINE_integer(
"The maximum number of tokens for the question. Questions longer than " "The maximum number of tokens for the question. Questions longer than "
"this will be truncated to this length.") "this will be truncated to this length.")
flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
# Shared flags across BERT fine-tuning tasks. # Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None, flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.") "The vocabulary file that the BERT model was trained on.")
...@@ -88,10 +93,6 @@ flags.DEFINE_integer( ...@@ -88,10 +93,6 @@ flags.DEFINE_integer(
"Sequences longer than this will be truncated, and sequences shorter " "Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.") "than this will be padded.")
flags.DEFINE_bool(
"version_2_with_negative", False,
"If true, the SQuAD examples contain some that do not have an answer.")
def generate_classifier_dataset(): def generate_classifier_dataset():
"""Generates classifier dataset and returns input meta data.""" """Generates classifier dataset and returns input meta data."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment