chenpangpang / transformers / Commits / de276de1

Commit de276de1, authored Dec 03, 2019 by LysandreJik

Working evaluation

Parent: c835bc85
Showing 3 changed files with 505 additions and 141 deletions (+505 / -141):

  examples/run_squad.py                         +18   -25
  transformers/data/metrics/squad_metrics.py    +476  -108
  transformers/data/processors/squad.py         +11   -8
examples/run_squad.py  (view file @ de276de1)

@@ -16,7 +16,8 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function

-from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
+from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
+from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate

 import argparse
 import logging

@@ -230,9 +231,11 @@ def evaluate(args, model, tokenizer, prefix=""):
         model.eval()
         batch = tuple(t.to(args.device) for t in batch)

         with torch.no_grad():
-            inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
+            inputs = {'input_ids': batch[0],
+                      'attention_mask': batch[1]
+                      }
             if args.model_type != 'distilbert':
                 inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
             example_indices = batch[3]

@@ -244,18 +247,8 @@ def evaluate(args, model, tokenizer, prefix=""):
         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]
             unique_id = int(eval_feature.unique_id)

-            if args.model_type in ['xlnet', 'xlm']:
-                # XLNet uses a more complex post-processing procedure
-                result = RawResultExtended(unique_id=unique_id,
-                                           start_top_log_probs=to_list(outputs[0][i]),
-                                           start_top_index=to_list(outputs[1][i]),
-                                           end_top_log_probs=to_list(outputs[2][i]),
-                                           end_top_index=to_list(outputs[3][i]),
-                                           cls_logits=to_list(outputs[4][i]))
-            else:
-                result = RawResult(unique_id=unique_id,
-                                   start_logits=to_list(outputs[0][i]),
-                                   end_logits=to_list(outputs[1][i]))
+            result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])

             all_results.append(result)

     evalTime = timeit.default_timer() - start_time
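An aside, not part of the diff: the removed branches above also document what `outputs` holds for each model family, and the new one-liner simply forwards those per-example tensors into a `SquadResult` together with the feature's `unique_id`. A minimal sketch, assuming `to_list` is the usual detach-to-CPU-list helper defined elsewhere in run_squad.py (its body here is illustrative only):

```python
# Illustrative sketch only; the real to_list lives elsewhere in run_squad.py.
def to_list(tensor):
    return tensor.detach().cpu().tolist()

# Output ordering, as documented by the removed RawResult/RawResultExtended branches:
#   BERT-like heads : outputs[0] = start_logits,        outputs[1] = end_logits
#   XLNet/XLM heads : outputs[0] = start_top_log_probs, outputs[1] = start_top_index,
#                     outputs[2] = end_top_log_probs,   outputs[3] = end_top_index,
#                     outputs[4] = cls_logits
```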
@@ -271,22 +264,18 @@ def evaluate(args, model, tokenizer, prefix=""):
     if args.model_type in ['xlnet', 'xlm']:
         # XLNet uses a more complex post-processing procedure
-        write_predictions_extended(examples, features, all_results, args.n_best_size,
-                                   args.max_answer_length, output_prediction_file,
-                                   output_nbest_file, output_null_log_odds_file, args.predict_file,
-                                   model.config.start_n_top, model.config.end_n_top,
-                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
+        predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
+                                                   args.max_answer_length, output_prediction_file,
+                                                   output_nbest_file, output_null_log_odds_file, args.predict_file,
+                                                   model.config.start_n_top, model.config.end_n_top,
+                                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
-        write_predictions(examples, features, all_results, args.n_best_size,
-                          args.max_answer_length, args.do_lower_case, output_prediction_file,
-                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                          args.version_2_with_negative, args.null_score_diff_threshold)
+        predictions = compute_predictions(examples, features, all_results, args.n_best_size,
+                                          args.max_answer_length, args.do_lower_case, output_prediction_file,
+                                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
+                                          args.version_2_with_negative, args.null_score_diff_threshold)

-    # Evaluate with the official SQuAD script
-    evaluate_options = EVAL_OPTS(data_file=args.predict_file,
-                                 pred_file=output_prediction_file,
-                                 na_prob_file=output_null_log_odds_file)
-    results = evaluate_on_squad(evaluate_options)
+    results = squad_evaluate(examples, predictions)
     return results


 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
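For orientation (illustrative, not part of the commit): both `compute_predictions` variants return a mapping from `qas_id` to predicted answer text, and `squad_evaluate` (added in squad_metrics.py below) turns that mapping plus the `SquadExample`s into an EM/F1 dictionary. The id and answer below are made up:

```python
# Hypothetical values, for illustration only.
predictions = {"some-qas-id": "Denver Broncos"}  # qas_id -> predicted answer text
results = squad_evaluate(examples, predictions)
# results is an OrderedDict along the lines of
# {'exact': ..., 'f1': ..., 'total': ..., 'HasAns_exact': ..., 'NoAns_f1': ..., 'best_exact': ..., ...}
print(results["exact"], results["f1"])
```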
@@ -306,8 +295,12 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         logger.info("Creating features from dataset file at %s", input_file)
         processor = SquadV2Processor()
-        features = squad_convert_examples_to_features(examples=processor.get_dev_examples("examples/squad", only_first=100) if evaluate else processor.get_train_examples("examples/squad"),
+        examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
+        # import tensorflow_datasets as tfds
+        # tfds_examples = tfds.load("squad")
+        # examples = SquadV1Processor().get_examples_from_dataset(tfds_examples["validation"])
+        features = squad_convert_examples_to_features(examples=examples,
                                                       tokenizer=tokenizer,
                                                       max_seq_length=args.max_seq_length,
transformers/data/metrics/squad_metrics.py  (view file @ de276de1)

""" Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was
modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""

import json
import logging
import math
import collections
from io import open
from tqdm import tqdm

import string
import re

from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize

logger = logging.getLogger(__name__)


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
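As a quick illustration (not part of the file), the normalization lower-cases, strips punctuation and articles, and collapses whitespace:

```python
normalize_answer("The quick, brown Fox!")  # -> 'quick brown fox'
normalize_answer("Steve Smith's")          # -> 'steve smiths'
```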
def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
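A worked example (not part of the file) of the two span metrics on one gold/prediction pair:

```python
compute_exact("Steve Smith", "steve smith's")  # -> 0   ('steve smith' != 'steve smiths')
compute_f1("Steve Smith", "steve smith's")     # -> 0.5 (one shared token, precision = recall = 1/2)
```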
def get_raw_scores(examples, preds):
    """
    Computes the exact and f1 scores from the examples and the model predictions
    """
    exact_scores = {}
    f1_scores = {}

    for example in examples:
        qas_id = example.qas_id
        gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]

        if not gold_answers:
            # For unanswerable questions, only correct answer is empty string
            gold_answers = ['']

        if qas_id not in preds:
            print('Missing prediction for %s' % qas_id)
            continue

        prediction = preds[qas_id]
        exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
        f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)

    return exact_scores, f1_scores


def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores
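A small illustration (not part of the file): when a question's no-answer probability exceeds the threshold, its score is replaced by whether predicting "no answer" was actually correct:

```python
scores         = {"q1": 1, "q2": 0}
na_probs       = {"q1": 0.9, "q2": 0.1}
qid_to_has_ans = {"q1": True, "q2": False}
apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh=0.5)
# -> {'q1': 0.0, 'q2': 0}: q1 is predicted unanswerable (0.9 > 0.5) although it has an answer,
#    so its score drops to 0.0; q2 stays under the threshold and keeps its raw score.
```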
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores.values()) / total),
            ('f1', 100.0 * sum(f1_scores.values()) / total),
            ('total', total),
        ])
    else:
        total = len(qid_list)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
            ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
            ('total', total),
        ])
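For instance (illustrative only):

```python
make_eval_dict({"q1": 1.0, "q2": 0.0}, {"q1": 1.0, "q2": 0.5})
# -> OrderedDict([('exact', 50.0), ('f1', 75.0), ('total', 2)])
```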
def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for _, qid in enumerate(qid_list):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh
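A tiny worked example (not part of the file): with one answerable and one unanswerable question, the sweep finds the threshold that keeps q1's correct answer while overriding q2's spurious non-empty prediction:

```python
preds          = {"q1": "some answer", "q2": "spurious answer"}
exact_raw      = {"q1": 1, "q2": 0}
na_probs       = {"q1": 0.2, "q2": 0.8}
qid_to_has_ans = {"q1": True, "q2": False}
find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
# -> (100.0, 0.2): treat anything with na_prob > 0.2 as unanswerable.
```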
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)

    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh


def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
    qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
    has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
    no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]

    if no_answer_probs is None:
        no_answer_probs = {k: 0.0 for k in preds}

    exact, f1 = get_raw_scores(examples, preds)

    exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
    f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)

    evaluation = make_eval_dict(exact_threshold, f1_threshold)

    if has_answer_qids:
        has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
        merge_eval(evaluation, has_ans_eval, 'HasAns')

    if no_answer_qids:
        no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
        merge_eval(evaluation, no_ans_eval, 'NoAns')

    if no_answer_probs:
        find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)

    return evaluation


def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if verbose_logging:
            logger.info("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if verbose_logging:
            logger.info("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
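Running the comment's own example through the function (illustrative aside):

```python
get_final_text("steve smith", "Steve Smith's", do_lower_case=True)  # -> "Steve Smith"
```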
def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes


def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs
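Quick illustrations (not part of the file) of the two helpers:

```python
_get_best_indexes([0.1, 3.2, 1.5, 2.7], n_best_size=2)  # -> [1, 3]
_compute_softmax([1.0, 2.0])                             # -> [0.2689..., 0.7310...]
```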
def compute_predictions(all_examples, all_features, all_results, n_best_size,
                        max_answer_length, do_lower_case, output_prediction_file,
                        output_nbest_file, output_null_log_odds_file, verbose_logging,
                        ...

(The remaining parameters and the body of compute_predictions are collapsed in the diff view.)
@@ -204,132 +512,192 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,

This hunk drops the copies of get_final_text (with its alignment-heuristic comment), its helper _strip_spaces, _get_best_indexes and _compute_softmax that previously followed compute_predictions; they now appear earlier in the file, as shown above. It adds compute_predictions_extended after compute_predictions' closing return:

    return all_predictions


def compute_predictions_extended(all_examples, all_features, all_results, n_best_size,
                                 max_answer_length, output_prediction_file,
                                 output_nbest_file, output_null_log_odds_file,
                                 orig_data_file, start_n_top, end_n_top,
                                 version_2_with_negative, tokenizer, verbose_logging):
    """ XLNet write prediction logic (more complex than Bert's).
        Write final predictions to the json file and log-odds of null if needed.

        Requires utils_squad_evaluate.py
    """
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index",
         "start_log_prob", "end_log_prob"])

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])

    logger.info("Writing predictions to: %s", output_prediction_file)
    # logger.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive

        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]

            cur_null_score = result.cls_logits

            # if we could have irrelevant answers, get the min score of irrelevant
            score_null = min(score_null, cur_null_score)

            for i in range(start_n_top):
                for j in range(end_n_top):
                    start_log_prob = result.start_top_log_probs[i]
                    start_index = result.start_top_index[i]

                    j_index = i * end_n_top + j

                    end_log_prob = result.end_top_log_probs[j_index]
                    end_index = result.end_top_index[j_index]

                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= feature.paragraph_len - 1:
                        continue
                    if end_index >= feature.paragraph_len - 1:
                        continue

                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue

                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_log_prob=start_log_prob,
                            end_log_prob=end_log_prob))

        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_log_prob + x.end_log_prob),
            reverse=True)

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]

            # XLNet un-tokenizer
            # Let's keep it simple for now and see if we need all this later.
            #
            # tok_start_to_orig_index = feature.tok_start_to_orig_index
            # tok_end_to_orig_index = feature.tok_end_to_orig_index
            # start_orig_pos = tok_start_to_orig_index[pred.start_index]
            # end_orig_pos = tok_end_to_orig_index[pred.end_index]
            # paragraph_text = example.paragraph_text
            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()

            # Previously used Bert untokenizer
            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

            # Clean whitespace
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)

            final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
                                        verbose_logging)

            if final_text in seen_predictions:
                continue

            seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_log_prob=pred.start_log_prob,
                    end_log_prob=pred.end_log_prob))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_log_prob + entry.end_log_prob)
            if not best_non_null_entry:
                best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_log_prob"] = entry.start_log_prob
            output["end_log_prob"] = entry.end_log_prob
            nbest_json.append(output)

        assert len(nbest_json) >= 1
        assert best_non_null_entry is not None

        score_diff = score_null
        scores_diff_json[example.qas_id] = score_diff
        # note(zhiliny): always predict best_non_null_entry
        # and the evaluation script will search for the best threshold
        all_predictions[example.qas_id] = best_non_null_entry.text

        all_nbest_json[example.qas_id] = nbest_json

    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with open(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if version_2_with_negative:
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    with open(orig_data_file, "r", encoding='utf-8') as reader:
        orig_data = json.load(reader)["data"]

    qid_to_has_ans = make_qid_to_has_ans(orig_data)
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]

    exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)

    out_eval = {}

    find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw,
                            scores_diff_json, qid_to_has_ans)

    return out_eval
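One detail worth calling out (editorial note): `result.end_top_log_probs` and `result.end_top_index` are consumed as a flattened start_n_top x end_n_top block per feature, which is why the inner loop indexes them with `j_index = i * end_n_top + j`. A stand-alone sketch of that indexing with made-up sizes follows; note also that `make_qid_to_has_ans` and `find_all_best_thresh_v2`, called at the end of the function, are presumably defined elsewhere in squad_metrics.py and are not shown in this diff.

```python
# Made-up sizes, for illustration of the flattened (i, j) indexing only.
start_n_top, end_n_top = 5, 5
flat = list(range(start_n_top * end_n_top))  # stand-in for result.end_top_log_probs
i, j = 2, 3
assert flat[i * end_n_top + j] == 13         # the (i, j) entry of the flattened block
```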
transformers/data/processors/squad.py  (view file @ de276de1)

@@ -306,13 +306,13 @@ class SquadProcessor(DataProcessor):
                     else:
                         is_impossible = False

-                    if not is_impossible and is_training:
-                        if (len(qa["answers"]) != 1):
-                            raise ValueError("For training, each question should have exactly 1 answer.")
-                        answer = qa["answers"][0]
-                        answer_text = answer['text']
-                        start_position_character = answer['answer_start']
+                    if not is_impossible:
+                        if is_training:
+                            answer = qa["answers"][0]
+                            answer_text = answer['text']
+                            start_position_character = answer['answer_start']
+                        else:
+                            answers = qa["answers"]

                     example = SquadExample(
                         qas_id=qas_id,

@@ -321,7 +321,8 @@ class SquadProcessor(DataProcessor):
                         answer_text=answer_text,
                         start_position_character=start_position_character,
                         title=title,
-                        is_impossible=is_impossible
+                        is_impossible=is_impossible,
+                        answers=answers
                     )

                     examples.append(example)

@@ -352,6 +353,7 @@ class SquadExample(object):
                  answer_text,
                  start_position_character,
                  title,
+                 answers=None,
                  is_impossible=False):
         self.qas_id = qas_id
         self.question_text = question_text

@@ -359,6 +361,7 @@ class SquadExample(object):
         self.answer_text = answer_text
         self.title = title
         self.is_impossible = is_impossible
+        self.answers = answers

         self.start_position, self.end_position = 0, 0
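To tie the three files together (illustrative sketch, not part of the commit): the new `answers` attribute carries the dev-set gold answers in SQuAD-JSON form, which is exactly what `get_raw_scores`/`squad_evaluate` in squad_metrics.py read back. The class and values below are stand-ins:

```python
# Stand-in example object with only the attributes squad_evaluate() touches.
class FakeSquadExample:
    qas_id = "q1"
    answers = [{"text": "Denver Broncos", "answer_start": 177}]  # as stored from qa["answers"]

predictions = {"q1": "Denver Broncos"}
print(squad_evaluate([FakeSquadExample()], predictions)["exact"])  # 100.0
```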