ModelZoo / ResNet50_tensorflow · Commit 1f3247f4

Unverified commit 1f3247f4, authored Mar 27, 2020 by Ayushman Kumar, committed by GitHub on Mar 27, 2020.

    Merge pull request #6 from tensorflow/master

    Updated

Parents: 370a4c8d, 0265f59c
Changes: 85 files in the commit. This page shows 20 changed files with 591 additions and 116 deletions (+591, −116).
official/README.md                              +2    −0
official/benchmark/bert_benchmark.py            +0    −33
official/benchmark/bert_squad_benchmark.py      +5    −19
official/benchmark/models/resnet_cifar_main.py  +5    −4
official/benchmark/ncf_keras_benchmark.py       +2    −2
official/nlp/bert/README.md                     +5    −0
official/nlp/bert/common_flags.py               +1    −1
official/nlp/bert/model_training_utils.py       +1    −1
official/nlp/bert/model_training_utils_test.py  +3    −2
official/nlp/bert/run_classifier.py             +8    −3
official/nlp/bert/run_pretraining.py            +1    −2
official/nlp/bert/run_squad.py                  +37   −5
official/nlp/bert/run_squad_helper.py           +65   −20
official/nlp/bert/squad_evaluate_v1_1.py        +108  −0
official/nlp/bert/squad_evaluate_v2_0.py        +252  −0
official/nlp/data/squad_lib.py                  +32   −7
official/nlp/data/squad_lib_sp.py               +34   −8
official/nlp/optimization.py                    +8    −8
official/nlp/transformer/misc.py                +4    −1
official/nlp/transformer/transformer_main.py    +18   −0
official/README.md

@@ -80,6 +80,8 @@ installable Official Models package. This is being tracked in

 ### Natural Language Processing

+*   [albert](nlp/albert): A Lite BERT for Self-supervised Learning of Language
+    Representations.
 *   [bert](nlp/bert): A powerful pre-trained language representation model:
     BERT, which stands for Bidirectional Encoder Representations from
     Transformers.
official/benchmark/bert_benchmark.py

@@ -212,39 +212,6 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
                                 'summaries/training_summary.txt')
     self._run_and_report_benchmark(summary_path, use_ds=False)

-  def benchmark_2_gpu_mrpc(self):
-    """Test BERT model performance with 2 GPUs."""
-
-    self._setup()
-    self.num_gpus = 2
-    FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_mrpc')
-    FLAGS.train_data_path = self.train_data_path
-    FLAGS.eval_data_path = self.eval_data_path
-    FLAGS.input_meta_data_path = self.input_meta_data_path
-    FLAGS.bert_config_file = self.bert_config_file
-    FLAGS.train_batch_size = 8
-    FLAGS.eval_batch_size = 8
-    summary_path = os.path.join(FLAGS.model_dir,
-                                'summaries/training_summary.txt')
-    self._run_and_report_benchmark(summary_path)
-
-  def benchmark_4_gpu_mrpc(self):
-    """Test BERT model performance with 4 GPUs."""
-
-    self._setup()
-    self.num_gpus = 4
-    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_mrpc')
-    FLAGS.train_data_path = self.train_data_path
-    FLAGS.eval_data_path = self.eval_data_path
-    FLAGS.input_meta_data_path = self.input_meta_data_path
-    FLAGS.bert_config_file = self.bert_config_file
-    FLAGS.train_batch_size = 16
-    summary_path = os.path.join(FLAGS.model_dir,
-                                'summaries/training_summary.txt')
-    self._run_and_report_benchmark(summary_path)
-
   def benchmark_8_gpu_mrpc(self):
     """Test BERT model performance with 8 GPUs."""
official/benchmark/bert_squad_benchmark.py

@@ -24,12 +24,12 @@ import time
 # pylint: disable=g-bad-import-order
 from absl import flags
 from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf
 # pylint: enable=g-bad-import-order

 from official.benchmark import bert_benchmark_utils as benchmark_utils
 from official.benchmark import squad_evaluate_v1_1
 from official.nlp.bert import run_squad
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils

@@ -70,18 +70,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
       return json.loads(reader.read().decode('utf-8'))

-  def _read_predictions_dataset_from_file(self):
-    """Reads the predictions dataset from a file."""
-    with tf.io.gfile.GFile(SQUAD_PREDICT_FILE, 'r') as reader:
-      dataset_json = json.load(reader)
-    return dataset_json['data']
-
-  def _read_predictions_from_file(self):
-    """Reads the predictions from a file."""
-    predictions_file = os.path.join(FLAGS.model_dir, 'predictions.json')
-    with tf.io.gfile.GFile(predictions_file, 'r') as reader:
-      return json.load(reader)
-
   def _get_distribution_strategy(self, ds_type='mirrored'):
     """Gets the distribution strategy.

@@ -135,12 +123,10 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     input_meta_data = self._read_input_meta_data_from_file()
     strategy = self._get_distribution_strategy(ds_type)

-    run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)
-    dataset = self._read_predictions_dataset_from_file()
-    predictions = self._read_predictions_from_file()
-    eval_metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
+    if input_meta_data.get('version_2_with_negative', False):
+      logging.error('In memory evaluation result for SQuAD v2 is not accurate')
+    eval_metrics = run_squad.eval_squad(strategy=strategy,
+                                        input_meta_data=input_meta_data)
     # Use F1 score as reported evaluation metric.
     self.eval_metrics = eval_metrics['f1']
official/benchmark/models/resnet_cifar_main.py

@@ -20,6 +20,7 @@ from __future__ import print_function
 from absl import app
 from absl import flags
+from absl import logging
 import numpy as np
 import tensorflow as tf

 from official.benchmark.models import resnet_cifar_model

@@ -100,7 +101,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     if lr != self.prev_lr:
       self.model.optimizer.learning_rate = lr  # lr should be a float here
       self.prev_lr = lr
-      tf.compat.v1.logging.debug(
+      logging.debug(
          'Epoch %05d Batch %05d: LearningRateBatchScheduler '
          'change learning rate to %s.', self.epochs, batch, lr)

@@ -137,8 +138,8 @@ def run(flags_obj):
   data_format = flags_obj.data_format
   if data_format is None:
-    data_format = ('channels_first'
-                   if tf.test.is_built_with_cuda() else 'channels_last')
+    data_format = ('channels_first'
+                   if tf.config.list_physical_devices('GPU')
+                   else 'channels_last')
   tf.keras.backend.set_image_data_format(data_format)

   strategy = distribution_utils.get_distribution_strategy(

@@ -280,6 +281,6 @@ def main(_):
 if __name__ == '__main__':
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_cifar_flags()
   app.run(main)
official/benchmark/ncf_keras_benchmark.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """Executes Keras benchmarks and accuracy tests."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

@@ -22,6 +21,7 @@ import os
 import time

 from absl import flags
+from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf

@@ -51,7 +51,7 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
   def _setup(self):
     """Sets up and resets flags before each test."""
     assert tf.version.VERSION.startswith('2.')
-    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+    logging.set_verbosity(logging.INFO)
     if NCFKerasBenchmarkBase.local_flags is None:
       ncf_common.define_ncf_flags()
       # Loads flags to get defaults to then override. List cannot be empty.
official/nlp/bert/README.md

@@ -269,6 +269,7 @@ python run_classifier.py \
   --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
   --train_batch_size=32 \
   --eval_batch_size=32 \
+  --steps_per_loop=1000 \
   --learning_rate=2e-5 \
   --num_train_epochs=3 \
   --model_dir=${MODEL_DIR} \

@@ -276,6 +277,10 @@ python run_classifier.py \
   --tpu=grpc://${TPU_IP_ADDRESS}:8470
 ```
+
+Note that we specify `steps_per_loop=1000` for TPU, because running a loop of
+training steps inside a `tf.function` can significantly increase TPU utilization
+and callbacks will not be called inside the loop.

 ### SQuAD 1.1

 The Stanford Question Answering Dataset (SQuAD) is a popular question answering
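The `steps_per_loop` note added above describes the host training loop used by the custom training utilities: many train steps run inside a single `tf.function` call, so Python-side work (callbacks, logging) only happens between loops. Below is a minimal, illustrative sketch of that pattern; the toy model, data, and `train_loop_body` name are assumptions for the example and are not code from this commit.

    import tensorflow as tf

    # Illustrative "steps per loop" host loop: the traced function runs
    # `steps_per_loop` training steps per call, so callbacks and logging on
    # the Python side run once per loop rather than once per step.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.build((None, 8))
    optimizer = tf.keras.optimizers.SGD(0.1)

    def make_dataset():
      x = tf.random.normal([1024, 8])
      y = tf.reduce_sum(x, axis=1, keepdims=True)
      return tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(32)

    @tf.function
    def train_loop_body(iterator, steps):
      """Runs `steps` training steps inside a single traced function."""
      loss = tf.constant(0.0)
      for _ in tf.range(steps):
        x, y = next(iterator)
        with tf.GradientTape() as tape:
          loss = tf.reduce_mean(tf.square(model(x) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return loss

    iterator = iter(make_dataset())
    steps_per_loop = 100  # e.g. 1000 on TPU, as the README suggests
    for loop in range(5):
      loss = train_loop_body(iterator, tf.constant(steps_per_loop))
      # Callback-like work (logging, checkpointing) happens only here.
      print('loop', loop, 'loss', float(loss))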
official/nlp/bert/common_flags.py

@@ -57,7 +57,7 @@ def define_common_bert_flags():
   flags.DEFINE_integer('num_train_epochs', 3,
                        'Total number of training epochs to perform.')
-  flags.DEFINE_integer('steps_per_loop', 200,
+  flags.DEFINE_integer('steps_per_loop', 1,
                        'Number of steps per graph-mode loop. Only training step '
                        'happens inside the loop. Callbacks will not be called '
                        'inside.')
official/modeling/model_training_utils.py → official/nlp/bert/model_training_utils.py

@@ -415,7 +415,7 @@ def run_customized_training_loop(
       # Runs several steps in the host while loop.
       steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)

-      if tf.test.is_built_with_cuda():
+      if tf.config.list_physical_devices('GPU'):
         # TODO(zongweiz): merge with train_steps once tf.while_loop
         # GPU performance bugs are fixed.
         for _ in range(steps):
official/modeling/model_training_utils_test.py → official/nlp/bert/model_training_utils_test.py

@@ -20,6 +20,7 @@ from __future__ import print_function
 import os

+from absl import logging
 from absl.testing import parameterized
 from absl.testing.absltest import mock
 import numpy as np

@@ -27,7 +28,7 @@ import tensorflow as tf
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
-from official.modeling import model_training_utils
+from official.nlp.bert import model_training_utils


 def eager_strategy_combinations():

@@ -125,7 +126,7 @@ def summaries_with_matching_keyword(keyword, summary_dir):
     if event.summary is not None:
       for value in event.summary.value:
         if keyword in value.tag:
-          tf.compat.v1.logging.error(event)
+          logging.error(event)
           yield event.summary
official/nlp/bert/run_classifier.py

@@ -25,8 +25,6 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf

-from official.modeling import model_training_utils
 from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models

@@ -34,6 +32,7 @@ from official.nlp.bert import common_flags
 from official.nlp.bert import configs as bert_configs
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
+from official.nlp.bert import model_training_utils
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils

@@ -156,6 +155,7 @@ def run_bert_classifier(strategy,
       init_checkpoint,
       epochs,
       steps_per_epoch,
+      steps_per_loop,
       eval_steps,
       custom_callbacks=custom_callbacks)

@@ -189,6 +189,7 @@ def run_keras_compile_fit(model_dir,
                           init_checkpoint,
                           epochs,
                           steps_per_epoch,
+                          steps_per_loop,
                           eval_steps,
                           custom_callbacks=None):
   """Runs BERT classifier model using Keras compile/fit API."""

@@ -203,7 +204,11 @@ def run_keras_compile_fit(model_dir,
     checkpoint = tf.train.Checkpoint(model=sub_model)
     checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

-  bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
+  bert_model.compile(
+      optimizer=optimizer,
+      loss=loss_fn,
+      metrics=[metric_fn()],
+      experimental_steps_per_execution=steps_per_loop)

   summary_dir = os.path.join(model_dir, 'summaries')
   summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
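In the compile/fit path above, `experimental_steps_per_execution` plays the same role as `steps_per_loop` in the custom loop: Keras batches that many steps into each `tf.function` call. A hedged, self-contained sketch of the argument with a toy model (not the BERT classifier) is below; note the argument exists under this name in TF releases around the time of this commit (TF 2.2/2.3) and is renamed `steps_per_execution` from TF 2.4 onward.

    import numpy as np
    import tensorflow as tf

    # Toy stand-in model, used only to show the compile argument.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
        tf.keras.layers.Dense(2),
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
        # Run 50 training steps per tf.function call; callbacks and the
        # progress bar only update between executions, mirroring steps_per_loop.
        experimental_steps_per_execution=50)

    x = np.random.rand(2000, 8).astype('float32')
    y = np.random.randint(0, 2, size=(2000,))
    model.fit(x, y, batch_size=32, epochs=2)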
official/nlp/bert/run_pretraining.py

@@ -22,14 +22,13 @@ from absl import flags
 from absl import logging
 import gin
 import tensorflow as tf

-from official.modeling import model_training_utils
 from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
 from official.nlp.bert import configs
 from official.nlp.bert import input_pipeline
+from official.nlp.bert import model_training_utils
 from official.utils.misc import distribution_utils
official/nlp/bert/run_squad.py

@@ -19,9 +19,13 @@ from __future__ import division
 from __future__ import print_function

 import json
 import os
+import tempfile
+import time

 from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf

 from official.nlp.bert import configs as bert_configs

@@ -52,12 +56,22 @@ def train_squad(strategy,
 def predict_squad(strategy, input_meta_data):
-  """Makes predictions for a squad dataset."""
+  """Makes predictions for the squad dataset."""
   bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                          do_lower_case=FLAGS.do_lower_case)
   run_squad_helper.predict_squad(
       strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+
+
+def eval_squad(strategy, input_meta_data):
+  """Evaluate on the squad dataset."""
+  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+  tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
+                                         do_lower_case=FLAGS.do_lower_case)
+  eval_metrics = run_squad_helper.eval_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+  return eval_metrics


 def export_squad(model_export_path, input_meta_data):

@@ -93,7 +107,8 @@ def main(_):
       num_gpus=FLAGS.num_gpus,
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
-  if FLAGS.mode in ('train', 'train_and_predict'):
+
+  if 'train' in FLAGS.mode:
     if FLAGS.log_steps:
       custom_callbacks = [keras_utils.TimeHistory(
           batch_size=FLAGS.train_batch_size,

@@ -109,8 +124,25 @@ def main(_):
         custom_callbacks=custom_callbacks,
         run_eagerly=FLAGS.run_eagerly,
     )
-  if FLAGS.mode in ('predict', 'train_and_predict'):
+  if 'predict' in FLAGS.mode:
     predict_squad(strategy, input_meta_data)
+  if 'eval' in FLAGS.mode:
+    eval_metrics = eval_squad(strategy, input_meta_data)
+    f1_score = eval_metrics['final_f1']
+    logging.info('SQuAD eval F1-score: %f', f1_score)
+    if (not strategy) or strategy.extended.should_save_summary:
+      summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
+    else:
+      summary_dir = tempfile.mkdtemp()
+    summary_writer = tf.summary.create_file_writer(
+        os.path.join(summary_dir, 'eval'))
+    with summary_writer.as_default():
+      # TODO(lehou): write to the correct step number.
+      tf.summary.scalar('F1-score', f1_score, step=0)
+      summary_writer.flush()
+    # Wait for some time, for the depending mldash/tensorboard jobs to finish
+    # exporting the final F1-score.
+    time.sleep(60)


 if __name__ == '__main__':
official/nlp/bert/run_squad_helper.py

@@ -18,18 +18,20 @@ from __future__ import division
 from __future__ import print_function

 import collections
+import json
 import os
 from absl import flags
 from absl import logging
 import tensorflow as tf
-from official.modeling import model_training_utils
 from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
+from official.nlp.bert import model_training_utils
+from official.nlp.bert import squad_evaluate_v1_1
+from official.nlp.bert import squad_evaluate_v2_0
 from official.nlp.data import squad_lib_sp
 from official.utils.misc import keras_utils

@@ -37,11 +39,15 @@ from official.utils.misc import keras_utils
 def define_common_squad_flags():
   """Defines common flags used by SQuAD tasks."""
   flags.DEFINE_enum(
-      'mode', 'train_and_predict',
-      ['train_and_predict', 'train', 'predict', 'export_only'],
-      'One of {"train_and_predict", "train", "predict", "export_only"}. '
-      '`train_and_predict`: both train and predict to a json file. '
+      'mode', 'train_and_eval',
+      ['train_and_eval', 'train_and_predict', 'train', 'eval', 'predict',
+       'export_only'],
+      'One of {"train_and_eval", "train_and_predict", '
+      '"train", "eval", "predict", "export_only"}. '
+      '`train_and_eval`: train & predict to json files & compute eval metrics. '
+      '`train_and_predict`: train & predict to json files. '
       '`train`: only trains the model. '
+      '`eval`: predict answers from squad json file & compute eval metrics. '
       '`predict`: predict answers from the squad json file. '
       '`export_only`: will take the latest checkpoint inside '
       'model_dir and export a `SavedModel`.')

@@ -271,7 +277,8 @@ def train_squad(strategy,
       post_allreduce_callbacks=[clip_by_global_norm_callback])


-def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
+def prediction_output_squad(
+    strategy, input_meta_data, tokenizer, bert_config, squad_lib):
   """Makes predictions for a squad dataset."""
   doc_stride = input_meta_data['doc_stride']
   max_query_length = input_meta_data['max_query_length']

@@ -322,23 +329,61 @@ def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
   all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
                                          eval_writer.filename, num_steps)

-  output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
-  output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
-  output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
-
-  squad_lib.write_predictions(
-      eval_examples,
-      eval_features,
-      all_results,
-      FLAGS.n_best_size,
-      FLAGS.max_answer_length,
-      FLAGS.do_lower_case,
-      output_prediction_file,
-      output_nbest_file,
-      output_null_log_odds_file,
-      version_2_with_negative=version_2_with_negative,
-      null_score_diff_threshold=FLAGS.null_score_diff_threshold,
-      verbose=FLAGS.verbose_logging)
+  all_predictions, all_nbest_json, scores_diff_json = (
+      squad_lib.postprocess_output(
+          eval_examples,
+          eval_features,
+          all_results,
+          FLAGS.n_best_size,
+          FLAGS.max_answer_length,
+          FLAGS.do_lower_case,
+          version_2_with_negative=version_2_with_negative,
+          null_score_diff_threshold=FLAGS.null_score_diff_threshold,
+          verbose=FLAGS.verbose_logging))
+
+  return all_predictions, all_nbest_json, scores_diff_json
+
+
+def dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+                  version_2_with_negative):
+  """Save output to json files."""
+  output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
+  output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
+  output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
+  logging.info('Writing predictions to: %s', (output_prediction_file))
+  logging.info('Writing nbest to: %s', (output_nbest_file))
+
+  squad_lib.write_to_json_files(all_predictions, output_prediction_file)
+  squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
+  if version_2_with_negative:
+    squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
+  """Get prediction results and evaluate them to hard drive."""
+  all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib)
+  dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+                input_meta_data.get('version_2_with_negative', False))
+
+
+def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
+  """Get prediction results and evaluate them against ground truth."""
+  all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib)
+  dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+                input_meta_data.get('version_2_with_negative', False))
+
+  with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
+    dataset_json = json.load(reader)
+    pred_dataset = dataset_json['data']
+  if input_meta_data.get('version_2_with_negative', False):
+    eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
+                                                scores_diff_json)
+  else:
+    eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
+  return eval_metrics


 def export_squad(model_export_path, input_meta_data, bert_config):
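The run_squad.py changes above dispatch the new modes with substring checks, so 'train_and_eval' runs both the train and eval branches while 'export_only' runs neither. A tiny sketch of that dispatch logic in isolation (the handler list is a placeholder, not the real training code):

    # Illustrative dispatch mirroring the substring checks in run_squad.main().
    def dispatch(mode):
      ran = []
      if 'train' in mode:
        ran.append('train')
      if 'predict' in mode:
        ran.append('predict')
      if 'eval' in mode:
        ran.append('eval')
      return ran

    assert dispatch('train_and_eval') == ['train', 'eval']
    assert dispatch('train_and_predict') == ['train', 'predict']
    assert dispatch('eval') == ['eval']
    assert dispatch('export_only') == []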
official/nlp/bert/squad_evaluate_v1_1.py (new file, 0 → 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).

The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.

The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import string

# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order


def _normalize_answer(s):
  """Lowers text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    return re.sub(r"\b(a|an|the)\b", " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _f1_score(prediction, ground_truth):
  """Computes F1 score by comparing prediction to ground truth."""
  prediction_tokens = _normalize_answer(prediction).split()
  ground_truth_tokens = _normalize_answer(ground_truth).split()
  prediction_counter = collections.Counter(prediction_tokens)
  ground_truth_counter = collections.Counter(ground_truth_tokens)
  common = prediction_counter & ground_truth_counter
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _exact_match_score(prediction, ground_truth):
  """Checks if predicted answer exactly matches ground truth answer."""
  return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  """Computes the max over all metric scores."""
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
  """Evaluates predictions for a dataset."""
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article["paragraphs"]:
      for qa in paragraph["qas"]:
        total += 1
        if qa["id"] not in predictions:
          message = ("Unanswered question " + qa["id"] +
                     " will receive score 0.")
          logging.error(message)
          continue
        ground_truths = [entry["text"] for entry in qa["answers"]]
        prediction = predictions[qa["id"]]
        exact_match += _metric_max_over_ground_truths(
            _exact_match_score, prediction, ground_truths)
        f1 += _metric_max_over_ground_truths(_f1_score, prediction,
                                             ground_truths)

  exact_match = exact_match / total
  f1 = f1 / total

  return {"exact_match": exact_match, "final_f1": f1}
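A small, self-contained example of calling the new v1.1 `evaluate` function. The nested article/paragraph/qas structure follows the SQuAD `data` format used in the code above; the content itself is made up, and the import assumes the TF Model Garden repository root is on the Python path.

    from official.nlp.bert import squad_evaluate_v1_1  # path added in this commit

    dataset = [{
        "paragraphs": [{
            "qas": [{
                "id": "q1",
                "answers": [{"text": "Paris"}, {"text": "the city of Paris"}],
            }]
        }]
    }]
    predictions = {"q1": "Paris"}

    metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
    # exact_match and final_f1 are fractions in [0, 1] here (not percentages).
    print(metrics)  # {'exact_match': 1.0, 'final_f1': 1.0}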
official/nlp/bert/squad_evaluate_v2_0.py (new file, 0 → 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation script for SQuAD version 2.0.

The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import string

from absl import logging


def _make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def _normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s):
  if not s:
    return []
  return _normalize_answer(s).split()


def _compute_exact(a_gold, a_pred):
  return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))


def _compute_f1(a_gold, a_pred):
  """Compute F1-score."""
  gold_toks = _get_tokens(a_gold)
  pred_toks = _get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if not gold_toks or not pred_toks:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _get_raw_scores(dataset, predictions):
  """Compute raw scores."""
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [a['text'] for a in qa['answers']
                        if _normalize_answer(a['text'])]
        if not gold_answers:
          # For unanswerable questions, only correct answer is empty string
          gold_answers = ['']
        if qid not in predictions:
          logging.error('Missing prediction for %s', qid)
          continue
        a_pred = predictions[qid]
        # Take max over all gold answers
        exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def _apply_no_ans_threshold(scores, na_probs, qid_to_has_ans,
                            na_prob_thresh=1.0):
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores


def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
  """Make evaluation result dictionary."""
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def _merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
  """Make evaluation dictionary containing average precision recall."""
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i + 1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  return {'ap': 100.0 * avg_prec}


def _run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                   qid_to_has_ans):
  """Run precision recall analysis and return result dictionary."""
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = _make_precision_recall_eval(exact_raw, na_probs, num_true_pos,
                                         qid_to_has_ans)
  pr_f1 = _make_precision_recall_eval(f1_raw, na_probs, num_true_pos,
                                      qid_to_has_ans)
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = _make_precision_recall_eval(oracle_scores, na_probs, num_true_pos,
                                          qid_to_has_ans)
  _merge_eval(main_eval, pr_exact, 'pr_exact')
  _merge_eval(main_eval, pr_f1, 'pr_f1')
  _merge_eval(main_eval, pr_oracle, 'pr_oracle')


def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
  """Find the best threshold for no answer probability."""
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for qid in qid_list:
    if qid not in scores:
      continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if predictions[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh


def _find_all_best_thresh(main_eval, predictions, exact_raw, f1_raw, na_probs,
                          qid_to_has_ans):
  best_exact, exact_thresh = _find_best_thresh(predictions, exact_raw, na_probs,
                                               qid_to_has_ans)
  best_f1, f1_thresh = _find_best_thresh(predictions, f1_raw, na_probs,
                                         qid_to_has_ans)
  main_eval['final_exact'] = best_exact
  main_eval['final_exact_thresh'] = exact_thresh
  main_eval['final_f1'] = best_f1
  main_eval['final_f1_thresh'] = f1_thresh


def evaluate(dataset, predictions, na_probs=None):
  """Evaluate prediction results."""
  new_orig_data = []
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        if qa['id'] in predictions:
          new_para = {'qas': [qa]}
          new_article = {'paragraphs': [new_para]}
          new_orig_data.append(new_article)
  dataset = new_orig_data

  if na_probs is None:
    na_probs = {k: 0.0 for k in predictions}
  qid_to_has_ans = _make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
  out_eval = _make_eval_dict(exact_thresh, f1_thresh)

  if has_ans_qids:
    has_ans_eval = _make_eval_dict(exact_thresh, f1_thresh,
                                   qid_list=has_ans_qids)
    _merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
    _merge_eval(out_eval, no_ans_eval, 'NoAns')
  _find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, na_probs,
                        qid_to_has_ans)
  _run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                 qid_to_has_ans)
  return out_eval
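A short usage example for the v2.0 `evaluate` function, which also accepts a no-answer probability map and reports metrics on a 0-100 scale along with the selected thresholds. The data is made up, and the import again assumes the Model Garden root is on the Python path.

    from official.nlp.bert import squad_evaluate_v2_0  # path added in this commit

    dataset = [{
        "paragraphs": [{
            "qas": [
                {"id": "q1", "answers": [{"text": "Paris"}]},
                {"id": "q2", "answers": []},  # unanswerable question
            ]
        }]
    }]
    predictions = {"q1": "Paris", "q2": ""}
    # Model's predicted probability that each question is unanswerable.
    na_probs = {"q1": 0.05, "q2": 0.9}

    metrics = squad_evaluate_v2_0.evaluate(dataset, predictions, na_probs)
    # Includes 'final_f1' and the matching no-answer threshold, which is what
    # run_squad.py logs and writes to the eval summary.
    print(metrics["final_f1"], metrics["final_f1_thresh"])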
official/nlp/data/squad_lib.py

@@ -506,6 +506,34 @@ def write_predictions(all_examples,
   logging.info("Writing predictions to: %s", (output_prediction_file))
   logging.info("Writing nbest to: %s", (output_nbest_file))

+  all_predictions, all_nbest_json, scores_diff_json = (
+      postprocess_output(
+          all_examples=all_examples,
+          all_features=all_features,
+          all_results=all_results,
+          n_best_size=n_best_size,
+          max_answer_length=max_answer_length,
+          do_lower_case=do_lower_case,
+          version_2_with_negative=version_2_with_negative,
+          null_score_diff_threshold=null_score_diff_threshold,
+          verbose=verbose))
+
+  write_to_json_files(all_predictions, output_prediction_file)
+  write_to_json_files(all_nbest_json, output_nbest_file)
+  if version_2_with_negative:
+    write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+                       all_features,
+                       all_results,
+                       n_best_size,
+                       max_answer_length,
+                       do_lower_case,
+                       version_2_with_negative=False,
+                       null_score_diff_threshold=0.0,
+                       verbose=False):
+  """Postprocess model output, to form prediction results."""
+
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
     example_index_to_features[feature.example_index].append(feature)

@@ -676,15 +704,12 @@ def write_predictions(all_examples,
     all_nbest_json[example.qas_id] = nbest_json

-  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
-    writer.write(json.dumps(all_predictions, indent=4) + "\n")
-
-  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
-    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
-
-  if version_2_with_negative:
-    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
-      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+  return all_predictions, all_nbest_json, scores_diff_json
+
+
+def write_to_json_files(json_records, json_file):
+  with tf.io.gfile.GFile(json_file, "w") as writer:
+    writer.write(json.dumps(json_records, indent=4) + "\n")


 def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
official/nlp/data/squad_lib_sp.py

@@ -575,10 +575,39 @@ def write_predictions(all_examples,
                       null_score_diff_threshold=0.0,
                       verbose=False):
   """Write final predictions to the json file and log-odds of null if needed."""
-  del do_lower_case, verbose
   logging.info("Writing predictions to: %s", (output_prediction_file))
   logging.info("Writing nbest to: %s", (output_nbest_file))

+  all_predictions, all_nbest_json, scores_diff_json = (
+      postprocess_output(
+          all_examples=all_examples,
+          all_features=all_features,
+          all_results=all_results,
+          n_best_size=n_best_size,
+          max_answer_length=max_answer_length,
+          do_lower_case=do_lower_case,
+          version_2_with_negative=version_2_with_negative,
+          null_score_diff_threshold=null_score_diff_threshold,
+          verbose=verbose))
+
+  write_to_json_files(all_predictions, output_prediction_file)
+  write_to_json_files(all_nbest_json, output_nbest_file)
+  if version_2_with_negative:
+    write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+                       all_features,
+                       all_results,
+                       n_best_size,
+                       max_answer_length,
+                       do_lower_case,
+                       version_2_with_negative=False,
+                       null_score_diff_threshold=0.0,
+                       verbose=False):
+  """Postprocess model output, to form prediction results."""
+  del do_lower_case, verbose
+
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
     example_index_to_features[feature.example_index].append(feature)

@@ -740,15 +769,12 @@ def write_predictions(all_examples,
     all_nbest_json[example.qas_id] = nbest_json

-  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
-    writer.write(json.dumps(all_predictions, indent=4) + "\n")
-
-  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
-    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
-
-  if version_2_with_negative:
-    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
-      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+  return all_predictions, all_nbest_json, scores_diff_json
+
+
+def write_to_json_files(json_records, json_file):
+  with tf.io.gfile.GFile(json_file, "w") as writer:
+    writer.write(json.dumps(json_records, indent=4) + "\n")


 def _get_best_indexes(logits, n_best_size):
official/nlp/optimization.py

@@ -140,19 +140,19 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
   def apply_gradients(self,
                       grads_and_vars,
                       name=None,
-                      all_reduce_sum_gradients=True):
+                      experimental_aggregate_gradients=True):
     grads, tvars = list(zip(*grads_and_vars))
-    if all_reduce_sum_gradients:
-      # when all_reduce_sum_gradients = False, apply_gradients() no longer
-      # implicitly allreduce gradients, users manually allreduce gradient and
-      # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
-      # will be moved to before the explicit allreduce to keep the math
-      # the same as TF 1 and pre TF 2.2 implementation.
+    if experimental_aggregate_gradients:
+      # when experimental_aggregate_gradients = False, apply_gradients() no
+      # longer implicitly allreduce gradients, users manually allreduce gradient
+      # and passed the allreduced grads_and_vars. For now, the
+      # clip_by_global_norm will be moved to before the explicit allreduce to
+      # keep the math the same as TF 1 and pre TF 2.2 implementation.
       (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
     return super(AdamWeightDecay, self).apply_gradients(
         zip(grads, tvars),
         name=name,
-        all_reduce_sum_gradients=all_reduce_sum_gradients)
+        experimental_aggregate_gradients=experimental_aggregate_gradients)

   def _get_lr(self, var_device, var_dtype, apply_state):
     """Retrieves the learning rate with the given state."""
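The rename above tracks the keyword that the base-class call passes through: `tf.keras.optimizers.Optimizer.apply_gradients` takes an `experimental_aggregate_gradients` argument in the TF 2.2-era API (later optimizer rewrites changed this again). A hedged sketch of the two call styles this `AdamWeightDecay` override supports, using a plain Adam optimizer and toy variables as stand-ins:

    import tensorflow as tf

    opt = tf.keras.optimizers.Adam(1e-3)  # tf.keras.optimizers.legacy.Adam on newer TF
    var = tf.Variable([1.0, 2.0])
    grad = tf.constant([0.1, 0.1])

    # Default path: the optimizer aggregates (all-reduces) gradients across
    # replicas itself before applying them.
    opt.apply_gradients([(grad, var)])

    # Opt-out path: the caller has already aggregated the gradients, e.g. with
    # an explicit all-reduce, so the optimizer must not aggregate them again.
    opt.apply_gradients([(grad, var)], experimental_aggregate_gradients=False)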
official/nlp/transformer/misc.py

@@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
   """Returns common callbacks."""
   callbacks = []
   if FLAGS.enable_time_history:
-    time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
+    time_callback = keras_utils.TimeHistory(
+        FLAGS.batch_size,
+        FLAGS.log_steps,
+        FLAGS.model_dir if FLAGS.enable_tensorboard else None)
     callbacks.append(time_callback)

   if FLAGS.enable_tensorboard:
official/nlp/transformer/transformer_main.py

@@ -246,6 +246,11 @@ class TransformerTask(object):
     callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

+    # Only TimeHistory callback is supported for CTL
+    if params["use_ctl"]:
+      callbacks = [cb for cb in callbacks
+                   if isinstance(cb, keras_utils.TimeHistory)]
+
     # TODO(b/139418525): Refactor the custom training loop logic.
     @tf.function
     def train_steps(iterator, steps):

@@ -299,8 +304,13 @@ class TransformerTask(object):
       if not self.use_tpu:
         raise NotImplementedError(
             "Custom training loop on GPUs is not implemented.")
+
       # Runs training steps.
       with summary_writer.as_default():
+        for cb in callbacks:
+          cb.on_epoch_begin(current_iteration)
+          cb.on_batch_begin(0)
+
         train_steps(
             train_ds_iterator,
             tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))

@@ -309,10 +319,18 @@ class TransformerTask(object):
         logging.info("Train Step: %d/%d / loss = %s", current_step,
                      flags_obj.train_steps, train_loss)

+        for cb in callbacks:
+          cb.on_batch_end(train_steps_per_eval - 1)
+          cb.on_epoch_end(current_iteration)
+
         if params["enable_tensorboard"]:
           for metric_obj in train_metrics:
             tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                         current_step)
             summary_writer.flush()
+
+      for cb in callbacks:
+        cb.on_train_end()

       if flags_obj.enable_checkpointing:
         # avoid check-pointing when running for benchmarking.
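The CTL changes above drive Keras callbacks by hand (`on_epoch_begin`, `on_batch_begin`, `on_batch_end`, `on_epoch_end`, `on_train_end`) because `model.fit` is not in charge of the loop. A minimal sketch of that pattern with a generic `tf.keras.callbacks.Callback`; the loop body is a placeholder, not the Transformer training code.

    import tensorflow as tf

    class LoggingCallback(tf.keras.callbacks.Callback):
      def on_epoch_begin(self, epoch, logs=None):
        print('starting epoch', epoch)

      def on_epoch_end(self, epoch, logs=None):
        print('finished epoch', epoch)

      def on_train_end(self, logs=None):
        print('training done')

    callbacks = [LoggingCallback()]
    num_epochs = 2
    steps_per_epoch = 10

    for epoch in range(num_epochs):
      for cb in callbacks:
        cb.on_epoch_begin(epoch)
        cb.on_batch_begin(0)
      # ... run `steps_per_epoch` training steps here (placeholder) ...
      for cb in callbacks:
        cb.on_batch_end(steps_per_epoch - 1)
        cb.on_epoch_end(epoch)

    for cb in callbacks:
      cb.on_train_end()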