ModelZoo / ResNet50_tensorflow

Commit 0265f59c
Authored Mar 26, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 303225001

Parent: 1bd89dac
Showing 4 changed files with 278 additions and 26 deletions.
official/nlp/bert/run_squad.py            +17 -18
official/nlp/bert/run_squad_helper.py      +8  -7
official/nlp/bert/squad_evaluate_v1_1.py   +1  -1
official/nlp/bert/squad_evaluate_v2_0.py +252  -0
official/nlp/bert/run_squad.py

@@ -19,9 +19,10 @@ from __future__ import division
 from __future__ import print_function

 import json
 import os
 import tempfile
+import time

 from absl import app
 from absl import flags
 from absl import logging
@@ -126,24 +127,22 @@ def main(_):
   if 'predict' in FLAGS.mode:
     predict_squad(strategy, input_meta_data)
   if 'eval' in FLAGS.mode:
-    if input_meta_data.get('version_2_with_negative', False):
-      logging.error('SQuAD v2 eval is not supported. '
-                    'Falling back to predict mode.')
-      predict_squad(strategy, input_meta_data)
-    else:
-      eval_metrics = eval_squad(strategy, input_meta_data)
-      f1_score = eval_metrics['f1']
-      logging.info('SQuAD eval F1-score: %f', f1_score)
-      if (not strategy) or strategy.extended.should_save_summary:
-        summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
-      else:
-        summary_dir = tempfile.mkdtemp()
-      summary_writer = tf.summary.create_file_writer(
-          os.path.join(summary_dir, 'eval'))
-      with summary_writer.as_default():
-        # TODO(lehou): write to the correct step number.
-        tf.summary.scalar('F1-score', f1_score, step=0)
-        summary_writer.flush()
+    eval_metrics = eval_squad(strategy, input_meta_data)
+    f1_score = eval_metrics['final_f1']
+    logging.info('SQuAD eval F1-score: %f', f1_score)
+    if (not strategy) or strategy.extended.should_save_summary:
+      summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
+    else:
+      summary_dir = tempfile.mkdtemp()
+    summary_writer = tf.summary.create_file_writer(
+        os.path.join(summary_dir, 'eval'))
+    with summary_writer.as_default():
+      # TODO(lehou): write to the correct step number.
+      tf.summary.scalar('F1-score', f1_score, step=0)
+      summary_writer.flush()
+    # Wait for some time, for the depending mldash/tensorboard jobs to finish
+    # exporting the final F1-score.
+    time.sleep(60)


 if __name__ == '__main__':
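The reworked eval branch reads eval_metrics['final_f1'] and exports it through the standard TF2 summary API (tf.summary.create_file_writer, tf.summary.scalar, flush), as shown in the hunk above. A minimal standalone sketch of that export pattern, assuming nothing beyond TensorFlow 2 itself; the log directory and the F1 value are illustrative placeholders, not values used by run_squad.py:

import tensorflow as tf

# Illustrative placeholders: '/tmp/squad_eval' and 0.88 are made up.
summary_writer = tf.summary.create_file_writer('/tmp/squad_eval')
with summary_writer.as_default():
  # run_squad.py writes the score at step 0 (see the TODO in the diff above).
  tf.summary.scalar('F1-score', 0.88, step=0)
  summary_writer.flush()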
official/nlp/bert/run_squad_helper.py

@@ -31,6 +31,7 @@ from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
 from official.nlp.bert import model_training_utils
 from official.nlp.bert import squad_evaluate_v1_1
+from official.nlp.bert import squad_evaluate_v2_0
 from official.nlp.data import squad_lib_sp
 from official.utils.misc import keras_utils

@@ -373,16 +374,16 @@ def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
   dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
                 input_meta_data.get('version_2_with_negative', False))

+  with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
+    dataset_json = json.load(reader)
+    pred_dataset = dataset_json['data']
   if input_meta_data.get('version_2_with_negative', False):
-    # TODO(lehou): support in memory evaluation for SQuAD v2.
-    logging.error('SQuAD v2 eval is not supported. Skipping eval')
-    return None
+    eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset,
+                                                all_predictions,
+                                                scores_diff_json)
   else:
-    with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
-      dataset_json = json.load(reader)
-      pred_dataset = dataset_json['data']
     eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
-    return eval_metrics
+  return eval_metrics


 def export_squad(model_export_path, input_meta_data, bert_config):
official/nlp/bert/squad_evaluate_v1_1.py

@@ -105,4 +105,4 @@ def evaluate(dataset, predictions):
   exact_match = exact_match / total
   f1 = f1 / total

-  return {"exact_match": exact_match, "f1": f1}
+  return {"exact_match": exact_match, "final_f1": f1}
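Renaming the v1.1 result key from "f1" to "final_f1" aligns it with the keys the new v2 evaluator fills in via _find_all_best_thresh ('final_exact', 'final_f1'), so the updated run_squad.py can read a single key regardless of SQuAD version. A small sketch of the shared access pattern; the metric values here are invented:

# Invented numbers, shown only to illustrate the shared 'final_f1' key.
v1_metrics = {'exact_match': 81.0, 'final_f1': 88.5}  # squad_evaluate_v1_1.evaluate(...)
v2_metrics = {'final_exact': 78.2, 'final_f1': 81.9}  # subset of squad_evaluate_v2_0.evaluate(...)
f1_score = v1_metrics['final_f1']                     # same lookup works for either dict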
official/nlp/bert/squad_evaluate_v2_0.py  (new file, 0 → 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation script for SQuAD version 2.0.

The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import string

from absl import logging


def _make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def _normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s):
  if not s:
    return []
  return _normalize_answer(s).split()


def _compute_exact(a_gold, a_pred):
  return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))


def _compute_f1(a_gold, a_pred):
  """Compute F1-score."""
  gold_toks = _get_tokens(a_gold)
  pred_toks = _get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if not gold_toks or not pred_toks:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _get_raw_scores(dataset, predictions):
  """Compute raw scores."""
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [a['text'] for a in qa['answers']
                        if _normalize_answer(a['text'])]
        if not gold_answers:
          # For unanswerable questions, only correct answer is empty string
          gold_answers = ['']
        if qid not in predictions:
          logging.error('Missing prediction for %s', qid)
          continue
        a_pred = predictions[qid]
        # Take max over all gold answers
        exact_scores[qid] = max(
            _compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def _apply_no_ans_threshold(scores, na_probs, qid_to_has_ans,
                            na_prob_thresh=1.0):
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores


def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
  """Make evaluation result dictionary."""
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def _merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def _make_precision_recall_eval(scores, na_probs, num_true_pos,
                                qid_to_has_ans):
  """Make evaluation dictionary containing average recision recall."""
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i + 1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  return {'ap': 100.0 * avg_prec}


def _run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                   qid_to_has_ans):
  """Run precision recall analysis and return result dictionary."""
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = _make_precision_recall_eval(exact_raw, na_probs, num_true_pos,
                                         qid_to_has_ans)
  pr_f1 = _make_precision_recall_eval(f1_raw, na_probs, num_true_pos,
                                      qid_to_has_ans)
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = _make_precision_recall_eval(oracle_scores, na_probs,
                                          num_true_pos, qid_to_has_ans)
  _merge_eval(main_eval, pr_exact, 'pr_exact')
  _merge_eval(main_eval, pr_f1, 'pr_f1')
  _merge_eval(main_eval, pr_oracle, 'pr_oracle')


def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
  """Find the best threshold for no answer probability."""
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for qid in qid_list:
    if qid not in scores:
      continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if predictions[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh


def _find_all_best_thresh(main_eval, predictions, exact_raw, f1_raw, na_probs,
                          qid_to_has_ans):
  best_exact, exact_thresh = _find_best_thresh(predictions, exact_raw,
                                               na_probs, qid_to_has_ans)
  best_f1, f1_thresh = _find_best_thresh(predictions, f1_raw,
                                         na_probs, qid_to_has_ans)
  main_eval['final_exact'] = best_exact
  main_eval['final_exact_thresh'] = exact_thresh
  main_eval['final_f1'] = best_f1
  main_eval['final_f1_thresh'] = f1_thresh


def evaluate(dataset, predictions, na_probs=None):
  """Evaluate prediction results."""
  new_orig_data = []
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        if qa['id'] in predictions:
          new_para = {'qas': [qa]}
          new_article = {'paragraphs': [new_para]}
          new_orig_data.append(new_article)
  dataset = new_orig_data

  if na_probs is None:
    na_probs = {k: 0.0 for k in predictions}
  qid_to_has_ans = _make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
  out_eval = _make_eval_dict(exact_thresh, f1_thresh)

  if has_ans_qids:
    has_ans_eval = _make_eval_dict(exact_thresh, f1_thresh,
                                   qid_list=has_ans_qids)
    _merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh,
                                  qid_list=no_ans_qids)
    _merge_eval(out_eval, no_ans_eval, 'NoAns')
  _find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, na_probs,
                        qid_to_has_ans)
  _run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                 qid_to_has_ans)
  return out_eval
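To see how this module is driven end to end, here is a small self-contained usage sketch of evaluate(); every question id, answer text and no-answer score below is invented, and the toy dataset carries only the fields the evaluator actually reads ('paragraphs', 'qas', 'id', 'answers', 'text'):

from official.nlp.bert import squad_evaluate_v2_0

# Toy SQuAD-v2-style data: one answerable and one unanswerable question.
pred_dataset = [{
    'paragraphs': [{
        'qas': [
            {'id': 'q1', 'answers': [{'text': 'TensorFlow'}]},
            {'id': 'q2', 'answers': []},  # unanswerable
        ],
    }],
}]
all_predictions = {'q1': 'TensorFlow', 'q2': ''}
# Plays the role of na_probs: higher means "more likely unanswerable".
# run_squad_helper.py passes scores_diff_json here.
no_answer_scores = {'q1': -3.2, 'q2': 5.7}

eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
                                            no_answer_scores)
print(eval_metrics['final_f1'], eval_metrics['final_exact'])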