ModelZoo / ResNet50_tensorflow · Commits

Commit 9df6a3d6
Authored Nov 01, 2019 by Hongkun Yu; committed by A. Unique TensorFlower on Nov 01, 2019
Parent: c14f5f4d

Add squad xlnet accuracy test

PiperOrigin-RevId: 277992916
Changes: 2 changed files with 89 additions and 6 deletions (+89 -6)

  official/benchmark/xlnet_benchmark.py   +87 -5
  official/nlp/xlnet/run_squad.py          +2 -1
official/benchmark/xlnet_benchmark.py (view file @ 9df6a3d6)

@@ -30,21 +30,23 @@ import tensorflow as tf
 from official.benchmark import bert_benchmark_utils as benchmark_utils
 from official.nlp.xlnet import run_classifier
+from official.nlp.xlnet import run_squad

 # pylint: disable=line-too-long
 PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
 CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.train.tf_record'
 CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.dev.eval.tf_record'
+SQUAD_DATA_PATH = 'gs://tf-perfzero-data/xlnet/squadv2_cased/'
 # pylint: enable=line-too-long

 FLAGS = flags.FLAGS


-class XLNetClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
+class XLNetBenchmarkBase(benchmark_utils.BertBenchmarkBase):
   """Base class to hold methods common to test classes in the module."""

   def __init__(self, output_dir=None):
-    super(XLNetClassifyBenchmarkBase, self).__init__(output_dir)
+    super(XLNetBenchmarkBase, self).__init__(output_dir)
     self.num_epochs = None
     self.num_steps_per_epoch = None
@@ -53,9 +55,14 @@ class XLNetClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     """Starts XLNet classification task."""
     run_classifier.main(unused_argv=None)

+  @flagsaver.flagsaver
+  def _run_xlnet_squad(self):
+    """Starts XLNet classification task."""
+    run_squad.main(unused_argv=None)
+

-class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
-  """Short accuracy test for XLNet model.
+class XLNetClassifyAccuracy(XLNetBenchmarkBase):
+  """Short accuracy test for XLNet classifier model.

   Tests XLNet classification task model accuracy. The naming
   convention of below test cases follow
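The new `_run_xlnet_squad` helper mirrors the existing classifier runner: the `@flagsaver.flagsaver` decorator snapshots all absl flag values before the call and restores them afterwards, so each benchmark method can overwrite `FLAGS` freely without leaking settings into the next test. A minimal, self-contained sketch of that behaviour (the flag defined here is illustrative only, not one of the benchmark's flags):

from absl import flags
from absl.testing import flagsaver

flags.DEFINE_integer('demo_batch_size', 8, 'Illustrative flag, not from the benchmark.')
FLAGS = flags.FLAGS


@flagsaver.flagsaver
def run_with_scratch_flags():
  # Mutations made inside the decorated call are rolled back on return.
  FLAGS.demo_batch_size = 16
  return FLAGS.demo_batch_size


if __name__ == '__main__':
  FLAGS(['demo'])                    # parse flags with a dummy argv
  print(run_with_scratch_flags())    # 16 inside the decorated call
  print(FLAGS.demo_batch_size)       # 8 again: flagsaver restored the original

This is why `_setup` below can set two dozen flags per test case without any explicit cleanup.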
@@ -93,7 +100,6 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
     FLAGS.test_data_size = 25024
     FLAGS.train_batch_size = 16
     FLAGS.seq_len = 512
-    FLAGS.reuse_len = 256
     FLAGS.mem_len = 0
     FLAGS.n_layer = 24
     FLAGS.d_model = 1024
@@ -126,5 +132,81 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
     self._run_and_report_benchmark(summary_path)


+class XLNetSquadAccuracy(XLNetBenchmarkBase):
+  """Short accuracy test for XLNet squad model.
+
+  Tests XLNet squad task model accuracy. The naming
+  convention of below test cases follow
+  `benchmark_(number of gpus)_gpu_(dataset type)` format.
+  """
+
+  def __init__(self, output_dir=None, **kwargs):
+    self.train_data_path = SQUAD_DATA_PATH
+    self.predict_file = os.path.join(SQUAD_DATA_PATH, "dev-v2.0.json")
+    self.test_data_path = os.path.join(SQUAD_DATA_PATH, "12048.eval.tf_record")
+    self.spiece_model_file = os.path.join(SQUAD_DATA_PATH, "spiece.cased.model")
+    self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
+
+    super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)
+
+  def _run_and_report_benchmark(self,
+                                training_summary_path,
+                                min_accuracy=0.87,
+                                max_accuracy=0.89):
+    """Starts XLNet accuracy benchmark test."""
+    start_time_sec = time.time()
+    self._run_xlnet_squad()
+    wall_time_sec = time.time() - start_time_sec
+
+    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
+      summary = json.loads(reader.read().decode('utf-8'))
+
+    super(XLNetSquadAccuracy, self)._report_benchmark(
+        stats=summary,
+        wall_time_sec=wall_time_sec,
+        min_accuracy=min_accuracy,
+        max_accuracy=max_accuracy)
+
+  def _setup(self):
+    super(XLNetSquadAccuracy, self)._setup()
+    FLAGS.train_batch_size = 16
+    FLAGS.seq_len = 512
+    FLAGS.mem_len = 0
+    FLAGS.n_layer = 24
+    FLAGS.d_model = 1024
+    FLAGS.d_embed = 1024
+    FLAGS.n_head = 16
+    FLAGS.d_head = 64
+    FLAGS.d_inner = 4096
+    FLAGS.untie_r = True
+    FLAGS.ff_activation = 'gelu'
+    FLAGS.strategy_type = 'mirror'
+    FLAGS.learning_rate = 3e-5
+    FLAGS.train_steps = 8000
+    FLAGS.warmup_steps = 1000
+    FLAGS.iterations = 1000
+    FLAGS.bi_data = False
+    FLAGS.init_checkpoint = self.pretrained_checkpoint_path
+    FLAGS.train_tfrecord_path = self.train_data_path
+    FLAGS.test_tfrecord_path = self.test_data_path
+    FLAGS.spiece_model_file = self.spiece_model_file
+    FLAGS.predict_file = self.predict_file
+    FLAGS.adam_epsilon = 1e-6
+    FLAGS.lr_layer_decay_rate = 0.75
+
+  def benchmark_8_gpu_squadv2(self):
+    """Run XLNet model squad v2 accuracy test with 8 GPUs."""
+    self._setup()
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squadv2')
+    FLAGS.predict_dir = FLAGS.model_dir
+    # Sets timer_callback to None as we do not use it now.
+    self.timer_callback = None
+    summary_path = os.path.join(FLAGS.model_dir,
+                                'summaries/training_summary.txt')
+    self._run_and_report_benchmark(summary_path)
+
+
 if __name__ == '__main__':
   tf.test.main()
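`_run_and_report_benchmark` hands the parsed training summary to `BertBenchmarkBase._report_benchmark`, which is not part of this diff. In effect, the new test passes only if the eval accuracy recorded in `summaries/training_summary.txt` falls inside `[min_accuracy, max_accuracy]` = [0.87, 0.89]. A standalone sketch of that check, assuming the summary is a JSON dict and using a hypothetical metric key (`'eval_metrics'`); the real field name is determined by `run_squad.py` and the base class, neither shown here:

import json

import tensorflow as tf


def check_squad_accuracy(training_summary_path,
                         min_accuracy=0.87,
                         max_accuracy=0.89):
  """Reads a training-summary JSON and verifies the accuracy bounds.

  The 'eval_metrics' key is a placeholder for illustration; the actual key is
  whatever run_squad.py writes into summaries/training_summary.txt.
  """
  with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
    summary = json.loads(reader.read().decode('utf-8'))
  accuracy = float(summary['eval_metrics'])  # placeholder key name
  if not min_accuracy <= accuracy <= max_accuracy:
    raise AssertionError(
        'accuracy %.4f outside expected range [%.2f, %.2f]' %
        (accuracy, min_accuracy, max_accuracy))
  return accuracy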
official/nlp/xlnet/run_squad.py (view file @ 9df6a3d6)
@@ -270,7 +270,8 @@ def main(unused_argv):
       logging.info("finishing reading pickle file...")
   else:
     sp_model = spm.SentencePieceProcessor()
-    sp_model.Load(FLAGS.spiece_model_file)
+    sp_model.LoadFromSerializedProto(
+        tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
     spm_basename = os.path.basename(FLAGS.spiece_model_file)
     eval_features = squad_utils.create_eval_data(
         spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
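The single change to `run_squad.py` replaces `SentencePieceProcessor.Load`, which reads only from the local filesystem, with `LoadFromSerializedProto` fed by `tf.io.gfile.GFile`. The commit message does not spell out the motivation, but the benchmark above points `FLAGS.spiece_model_file` at a `gs://` object, and going through `tf.io.gfile` lets the same code read local and Cloud Storage paths alike. A small usage sketch (the bucket path is a placeholder):

import sentencepiece as spm
import tensorflow as tf

sp_model = spm.SentencePieceProcessor()

# sp_model.Load('/tmp/spiece.cased.model')   # works for local paths only

# Reading the raw model proto through tf.io.gfile also works for remote
# locations such as gs://..., matching how the benchmark supplies the file.
serialized = tf.io.gfile.GFile('gs://my-bucket/spiece.cased.model', 'rb').read()
sp_model.LoadFromSerializedProto(serialized)

print(sp_model.EncodeAsPieces('XLNet SQuAD benchmark'))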