Commit 9df6a3d6 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Add squad xlnet accuracy test

PiperOrigin-RevId: 277992916
parent c14f5f4d
...@@ -30,21 +30,23 @@ import tensorflow as tf ...@@ -30,21 +30,23 @@ import tensorflow as tf
from official.benchmark import bert_benchmark_utils as benchmark_utils from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp.xlnet import run_classifier from official.nlp.xlnet import run_classifier
from official.nlp.xlnet import run_squad
# pylint: disable=line-too-long # pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1' PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.train.tf_record' CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.train.tf_record'
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.dev.eval.tf_record' CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.dev.eval.tf_record'
SQUAD_DATA_PATH = 'gs://tf-perfzero-data/xlnet/squadv2_cased/'
# pylint: enable=line-too-long # pylint: enable=line-too-long
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
class XLNetClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): class XLNetBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module.""" """Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None): def __init__(self, output_dir=None):
super(XLNetClassifyBenchmarkBase, self).__init__(output_dir) super(XLNetBenchmarkBase, self).__init__(output_dir)
self.num_epochs = None self.num_epochs = None
self.num_steps_per_epoch = None self.num_steps_per_epoch = None
...@@ -53,9 +55,14 @@ class XLNetClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -53,9 +55,14 @@ class XLNetClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Starts XLNet classification task.""" """Starts XLNet classification task."""
run_classifier.main(unused_argv=None) run_classifier.main(unused_argv=None)
@flagsaver.flagsaver
def _run_xlnet_squad(self):
"""Starts XLNet classification task."""
run_squad.main(unused_argv=None)
class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
"""Short accuracy test for XLNet model. class XLNetClassifyAccuracy(XLNetBenchmarkBase):
"""Short accuracy test for XLNet classifier model.
Tests XLNet classification task model accuracy. The naming Tests XLNet classification task model accuracy. The naming
convention of below test cases follow convention of below test cases follow
...@@ -93,7 +100,6 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase): ...@@ -93,7 +100,6 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
FLAGS.test_data_size = 25024 FLAGS.test_data_size = 25024
FLAGS.train_batch_size = 16 FLAGS.train_batch_size = 16
FLAGS.seq_len = 512 FLAGS.seq_len = 512
FLAGS.reuse_len = 256
FLAGS.mem_len = 0 FLAGS.mem_len = 0
FLAGS.n_layer = 24 FLAGS.n_layer = 24
FLAGS.d_model = 1024 FLAGS.d_model = 1024
...@@ -126,5 +132,81 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase): ...@@ -126,5 +132,81 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
self._run_and_report_benchmark(summary_path) self._run_and_report_benchmark(summary_path)
class XLNetSquadAccuracy(XLNetBenchmarkBase):
"""Short accuracy test for XLNet squad model.
Tests XLNet squad task model accuracy. The naming
convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, output_dir=None, **kwargs):
self.train_data_path = SQUAD_DATA_PATH
self.predict_file = os.path.join(SQUAD_DATA_PATH, "dev-v2.0.json")
self.test_data_path = os.path.join(SQUAD_DATA_PATH, "12048.eval.tf_record")
self.spiece_model_file = os.path.join(SQUAD_DATA_PATH, "spiece.cased.model")
self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0.87,
max_accuracy=0.89):
"""Starts XLNet accuracy benchmark test."""
start_time_sec = time.time()
self._run_xlnet_squad()
wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
super(XLNetSquadAccuracy, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
max_accuracy=max_accuracy)
def _setup(self):
super(XLNetSquadAccuracy, self)._setup()
FLAGS.train_batch_size = 16
FLAGS.seq_len = 512
FLAGS.mem_len = 0
FLAGS.n_layer = 24
FLAGS.d_model = 1024
FLAGS.d_embed = 1024
FLAGS.n_head = 16
FLAGS.d_head = 64
FLAGS.d_inner = 4096
FLAGS.untie_r = True
FLAGS.ff_activation = 'gelu'
FLAGS.strategy_type = 'mirror'
FLAGS.learning_rate = 3e-5
FLAGS.train_steps = 8000
FLAGS.warmup_steps = 1000
FLAGS.iterations = 1000
FLAGS.bi_data = False
FLAGS.init_checkpoint = self.pretrained_checkpoint_path
FLAGS.train_tfrecord_path = self.train_data_path
FLAGS.test_tfrecord_path = self.test_data_path
FLAGS.spiece_model_file = self.spiece_model_file
FLAGS.predict_file = self.predict_file
FLAGS.adam_epsilon=1e-6
FLAGS.lr_layer_decay_rate=0.75
def benchmark_8_gpu_squadv2(self):
"""Run XLNet model squad v2 accuracy test with 8 GPUs."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squadv2')
FLAGS.predict_dir = FLAGS.model_dir
# Sets timer_callback to None as we do not use it now.
self.timer_callback = None
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -270,7 +270,8 @@ def main(unused_argv): ...@@ -270,7 +270,8 @@ def main(unused_argv):
logging.info("finishing reading pickle file...") logging.info("finishing reading pickle file...")
else: else:
sp_model = spm.SentencePieceProcessor() sp_model = spm.SentencePieceProcessor()
sp_model.Load(FLAGS.spiece_model_file) sp_model.LoadFromSerializedProto(
tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
spm_basename = os.path.basename(FLAGS.spiece_model_file) spm_basename = os.path.basename(FLAGS.spiece_model_file)
eval_features = squad_utils.create_eval_data( eval_features = squad_utils.create_eval_data(
spm_basename, sp_model, eval_examples, FLAGS.max_seq_length, spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment