Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9df6a3d6
Commit
9df6a3d6
authored
Nov 01, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Nov 01, 2019
Browse files
Add squad xlnet accuracy test
PiperOrigin-RevId: 277992916
parent
c14f5f4d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
6 deletions
+89
-6
official/benchmark/xlnet_benchmark.py
official/benchmark/xlnet_benchmark.py
+87
-5
official/nlp/xlnet/run_squad.py
official/nlp/xlnet/run_squad.py
+2
-1
No files found.
official/benchmark/xlnet_benchmark.py
View file @
9df6a3d6
...
...
@@ -30,21 +30,23 @@ import tensorflow as tf
from
official.benchmark
import
bert_benchmark_utils
as
benchmark_utils
from
official.nlp.xlnet
import
run_classifier
from
official.nlp.xlnet
import
run_squad
# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH
=
'gs://cloud-tpu-checkpoints/xlnet/large/xlnet_model-1'
CLASSIFIER_TRAIN_DATA_PATH
=
'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.train.tf_record'
CLASSIFIER_EVAL_DATA_PATH
=
'gs://tf-perfzero-data/xlnet/imdb/spiece.model.len-512.dev.eval.tf_record'
SQUAD_DATA_PATH
=
'gs://tf-perfzero-data/xlnet/squadv2_cased/'
# pylint: enable=line-too-long
FLAGS
=
flags
.
FLAGS
class
XLNet
Classify
BenchmarkBase
(
benchmark_utils
.
BertBenchmarkBase
):
class
XLNetBenchmarkBase
(
benchmark_utils
.
BertBenchmarkBase
):
"""Base class to hold methods common to test classes in the module."""
def
__init__
(
self
,
output_dir
=
None
):
super
(
XLNet
Classify
BenchmarkBase
,
self
).
__init__
(
output_dir
)
super
(
XLNetBenchmarkBase
,
self
).
__init__
(
output_dir
)
self
.
num_epochs
=
None
self
.
num_steps_per_epoch
=
None
...
...
@@ -53,9 +55,14 @@ class XLNetClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Starts XLNet classification task."""
run_classifier
.
main
(
unused_argv
=
None
)
@flagsaver.flagsaver
def _run_xlnet_squad(self):
  """Starts XLNet squad task.

  Calls `run_squad.main` with no argv. The `flagsaver.flagsaver` decorator
  wraps the call.
  """
  run_squad.main(unused_argv=None)
class
XLNetClassifyAccuracy
(
XLNetClassifyBenchmarkBase
):
"""Short accuracy test for XLNet model.
class
XLNetClassifyAccuracy
(
XLNetBenchmarkBase
):
"""Short accuracy test for XLNet classifier model.
Tests XLNet classification task model accuracy. The naming
convention of the test cases below follows
...
...
@@ -93,7 +100,6 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
FLAGS
.
test_data_size
=
25024
FLAGS
.
train_batch_size
=
16
FLAGS
.
seq_len
=
512
FLAGS
.
reuse_len
=
256
FLAGS
.
mem_len
=
0
FLAGS
.
n_layer
=
24
FLAGS
.
d_model
=
1024
...
...
@@ -126,5 +132,81 @@ class XLNetClassifyAccuracy(XLNetClassifyBenchmarkBase):
self
.
_run_and_report_benchmark
(
summary_path
)
class XLNetSquadAccuracy(XLNetBenchmarkBase):
  """Short accuracy benchmark for the XLNet squad model.

  Exercises the XLNet squad task and reports its accuracy. Benchmark
  methods are named using the
  `benchmark_(number of gpus)_gpu_(dataset type)` format.
  """

  def __init__(self, output_dir=None, **kwargs):
    """Records the squad data locations, then initializes the base class.

    Args:
      output_dir: Passed through to the base-class constructor.
      **kwargs: Accepted for call compatibility; not used here.
    """
    # All squad artifacts live under the same data root.
    self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH
    self.train_data_path = SQUAD_DATA_PATH
    self.test_data_path = os.path.join(SQUAD_DATA_PATH, '12048.eval.tf_record')
    self.predict_file = os.path.join(SQUAD_DATA_PATH, 'dev-v2.0.json')
    self.spiece_model_file = os.path.join(SQUAD_DATA_PATH,
                                          'spiece.cased.model')

    super(XLNetSquadAccuracy, self).__init__(output_dir=output_dir)

  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0.87,
                                max_accuracy=0.89):
    """Runs the squad task and reports the stats it wrote.

    Args:
      training_summary_path: Path of the JSON training summary that is read
        after the run finishes.
      min_accuracy: Forwarded to `_report_benchmark`.
      max_accuracy: Forwarded to `_report_benchmark`.
    """
    began_at_sec = time.time()
    self._run_xlnet_squad()
    elapsed_sec = time.time() - began_at_sec

    # The summary is stored as UTF-8 encoded JSON.
    with tf.io.gfile.GFile(training_summary_path, 'rb') as summary_file:
      raw_summary = summary_file.read()
    stats = json.loads(raw_summary.decode('utf-8'))

    super(XLNetSquadAccuracy, self)._report_benchmark(
        stats=stats,
        wall_time_sec=elapsed_sec,
        min_accuracy=min_accuracy,
        max_accuracy=max_accuracy)

  def _setup(self):
    """Runs the base setup, then assigns the flags for the squad run."""
    super(XLNetSquadAccuracy, self)._setup()

    # (flag name, value) pairs, applied in this exact order.
    flag_values = (
        ('train_batch_size', 16),
        ('seq_len', 512),
        ('mem_len', 0),
        ('n_layer', 24),
        ('d_model', 1024),
        ('d_embed', 1024),
        ('n_head', 16),
        ('d_head', 64),
        ('d_inner', 4096),
        ('untie_r', True),
        ('ff_activation', 'gelu'),
        ('strategy_type', 'mirror'),
        ('learning_rate', 3e-5),
        ('train_steps', 8000),
        ('warmup_steps', 1000),
        ('iterations', 1000),
        ('bi_data', False),
        ('init_checkpoint', self.pretrained_checkpoint_path),
        ('train_tfrecord_path', self.train_data_path),
        ('test_tfrecord_path', self.test_data_path),
        ('spiece_model_file', self.spiece_model_file),
        ('predict_file', self.predict_file),
        ('adam_epsilon', 1e-6),
        ('lr_layer_decay_rate', 0.75),
    )
    for flag_name, flag_value in flag_values:
      setattr(FLAGS, flag_name, flag_value)

  def benchmark_8_gpu_squadv2(self):
    """Runs the squad v2 accuracy benchmark (8-GPU variant, per its name)."""
    self._setup()

    # Predictions are written to the same directory as the model.
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squadv2')
    FLAGS.predict_dir = FLAGS.model_dir

    # The timer callback is not used by this benchmark, so clear it.
    self.timer_callback = None

    self._run_and_report_benchmark(
        os.path.join(FLAGS.model_dir, 'summaries/training_summary.txt'))
# Call tf.test.main() when this module is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()
official/nlp/xlnet/run_squad.py
View file @
9df6a3d6
...
...
@@ -270,7 +270,8 @@ def main(unused_argv):
logging
.
info
(
"finishing reading pickle file..."
)
else
:
sp_model
=
spm
.
SentencePieceProcessor
()
sp_model
.
Load
(
FLAGS
.
spiece_model_file
)
sp_model
.
LoadFromSerializedProto
(
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
spiece_model_file
,
"rb"
).
read
())
spm_basename
=
os
.
path
.
basename
(
FLAGS
.
spiece_model_file
)
eval_features
=
squad_utils
.
create_eval_data
(
spm_basename
,
sp_model
,
eval_examples
,
FLAGS
.
max_seq_length
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment