chenpangpang / transformers · Commits · de276de1

Commit de276de1, authored Dec 03, 2019 by LysandreJik
Working evaluation
parent c835bc85

Showing 3 changed files, with 505 additions and 141 deletions (+505 -141):

  examples/run_squad.py                          +18   -25
  transformers/data/metrics/squad_metrics.py     +476  -108
  transformers/data/processors/squad.py          +11   -8
examples/run_squad.py  (+18 -25)

@@ -16,7 +16,8 @@
 """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""

 from __future__ import absolute_import, division, print_function

-from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
+from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
+from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate

 import argparse
 import logging

@@ -230,9 +231,11 @@ def evaluate(args, model, tokenizer, prefix=""):
         model.eval()
         batch = tuple(t.to(args.device) for t in batch)
         with torch.no_grad():
             inputs = {'input_ids':      batch[0],
                       'attention_mask': batch[1]
                       }
             if args.model_type != 'distilbert':
                 inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]  # XLM don't use segment_ids
             example_indices = batch[3]

@@ -244,18 +247,8 @@ def evaluate(args, model, tokenizer, prefix=""):
         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]
             unique_id = int(eval_feature.unique_id)
-            if args.model_type in ['xlnet', 'xlm']:
-                # XLNet uses a more complex post-processing procedure
-                result = RawResultExtended(unique_id=unique_id,
-                                           start_top_log_probs=to_list(outputs[0][i]),
-                                           start_top_index=to_list(outputs[1][i]),
-                                           end_top_log_probs=to_list(outputs[2][i]),
-                                           end_top_index=to_list(outputs[3][i]),
-                                           cls_logits=to_list(outputs[4][i]))
-            else:
-                result = RawResult(unique_id=unique_id,
-                                   start_logits=to_list(outputs[0][i]),
-                                   end_logits=to_list(outputs[1][i]))
+            result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
             all_results.append(result)

     evalTime = timeit.default_timer() - start_time

@@ -271,22 +264,18 @@ def evaluate(args, model, tokenizer, prefix=""):
     if args.model_type in ['xlnet', 'xlm']:
         # XLNet uses a more complex post-processing procedure
-        write_predictions_extended(examples, features, all_results, args.n_best_size,
+        predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
                                    args.max_answer_length, output_prediction_file,
                                    output_nbest_file, output_null_log_odds_file, args.predict_file,
                                    model.config.start_n_top, model.config.end_n_top,
                                    args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
-        write_predictions(examples, features, all_results, args.n_best_size,
+        predictions = compute_predictions(examples, features, all_results, args.n_best_size,
                           args.max_answer_length, args.do_lower_case, output_prediction_file,
                           output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                           args.version_2_with_negative, args.null_score_diff_threshold)

-    evaluate_options = EVAL_OPTS(data_file=args.predict_file,
-                                 pred_file=output_prediction_file,
-                                 na_prob_file=output_null_log_odds_file)
-    results = evaluate_on_squad(evaluate_options)
+    # Evaluate with the official SQuAD script
+    results = squad_evaluate(examples, predictions)
     return results


 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):

@@ -306,7 +295,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
         logger.info("Creating features from dataset file at %s", input_file)
         processor = SquadV2Processor()
-        examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
+        examples = processor.get_dev_examples("examples/squad", only_first=100) if evaluate else processor.get_train_examples("examples/squad")
+        # import tensorflow_datasets as tfds
+        # tfds_examples = tfds.load("squad")
+        # examples = SquadV1Processor().get_examples_from_dataset(tfds_examples["validation"])
         features = squad_convert_examples_to_features(
             examples=examples,
             tokenizer=tokenizer,
...
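The net effect in run_squad.py is that predictions are now computed and scored in-process, replacing the old EVAL_OPTS / evaluate_on_squad path. A minimal sketch of the new flow (illustrative only, not part of the commit): the n_best_size and output-file values below are placeholders, and `examples`, `features`, `all_results` are assumed to come from load_and_cache_examples() and the prediction loop above.

    predictions = compute_predictions(examples, features, all_results,
                                      20,                   # n_best_size (placeholder)
                                      30,                   # max_answer_length (placeholder)
                                      True,                 # do_lower_case
                                      "predictions.json",   # output_prediction_file (placeholder)
                                      "nbest.json",         # output_nbest_file (placeholder)
                                      "null_odds.json",     # output_null_log_odds_file (placeholder)
                                      False,                # verbose_logging
                                      True,                 # version_2_with_negative
                                      0.0)                  # null_score_diff_threshold
    results = squad_evaluate(examples, predictions)  # OrderedDict with 'exact', 'f1', 'total', ...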
transformers/data/metrics/squad_metrics.py  (+476 -108)

Most of this file is new in this commit; the listing below shows the updated contents.

""" Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was
modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""

import json
import logging
import math
import collections
from io import open
from tqdm import tqdm
import string   # added in this commit
import re       # added in this commit

from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize

logger = logging.getLogger(__name__)
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
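For illustration (not part of the commit), here is how the two metrics behave on toy gold/prediction pairs; the numbers follow directly from the definitions above.

    compute_exact("The Eiffel Tower", "eiffel tower")  # 1: articles, case and punctuation are normalized away
    compute_f1("Steve Smith", "steve")                 # precision 1.0, recall 0.5, so F1 = 2*1.0*0.5/1.5 ~= 0.667
    compute_f1("", "")                                 # 1: two no-answers agree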
def get_raw_scores(examples, preds):
    """
    Computes the exact and f1 scores from the examples and the model predictions
    """
    exact_scores = {}
    f1_scores = {}

    for example in examples:
        qas_id = example.qas_id
        gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]

        if not gold_answers:
            # For unanswerable questions, only correct answer is empty string
            gold_answers = ['']

        if qas_id not in preds:
            print('Missing prediction for %s' % qas_id)
            continue

        prediction = preds[qas_id]
        exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
        f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)

    return exact_scores, f1_scores
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores


def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores.values()) / total),
            ('f1', 100.0 * sum(f1_scores.values()) / total),
            ('total', total),
        ])
    else:
        total = len(qid_list)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
            ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
            ('total', total),
        ])
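A small illustrative call (not from the commit) showing the shape of the aggregate dictionary:

    exact = {'q1': 1, 'q2': 0}
    f1 = {'q1': 1.0, 'q2': 0.4}
    make_eval_dict(exact, f1)                    # OrderedDict([('exact', 50.0), ('f1', 70.0), ('total', 2)])
    make_eval_dict(exact, f1, qid_list=['q1'])   # OrderedDict([('exact', 100.0), ('f1', 100.0), ('total', 1)])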
def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for _, qid in enumerate(qid_list):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh


def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)

    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
    qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
    has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
    no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]

    if no_answer_probs is None:
        no_answer_probs = {k: 0.0 for k in preds}

    exact, f1 = get_raw_scores(examples, preds)

    exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
    f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)

    evaluation = make_eval_dict(exact_threshold, f1_threshold)

    if has_answer_qids:
        has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
        merge_eval(evaluation, has_ans_eval, 'HasAns')

    if no_answer_qids:
        no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
        merge_eval(evaluation, no_ans_eval, 'NoAns')

    if no_answer_probs:
        find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)

    return evaluation
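An illustrative way to call the new entry point (a sketch, not part of the commit): `examples` are the SquadExample objects produced by a processor, and the predictions dict maps each qas_id to predicted answer text, exactly what run_squad.py now builds via compute_predictions. The "predict nothing" baseline below is purely for demonstration.

    processor = SquadV2Processor()
    examples = processor.get_dev_examples("examples/squad")
    predictions = {example.qas_id: "" for example in examples}   # trivial baseline, illustrative only
    metrics = squad_evaluate(examples, predictions)
    print(metrics["exact"], metrics["f1"], metrics["total"])
    # the dict also contains 'HasAns_*' / 'NoAns_*' splits and 'best_exact' / 'best_f1' threshold statistics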
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases in which case we just return `orig_text`.
    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
                        orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if verbose_logging:
            logger.info("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if verbose_logging:
            logger.info("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
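Illustrative only, mirroring the docstring's own example: projecting a normalized prediction back onto the original span.

    get_final_text("steve smith", "Steve Smith's", do_lower_case=True)
    # expected to return "Steve Smith"; whenever the character alignment heuristic
    # fails, the function falls back to returning orig_text unchanged.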
def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes


def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs
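A quick sanity check (illustrative, not from the commit): subtracting the max logit before exponentiating keeps math.exp from overflowing without changing the result.

    probs = _compute_softmax([1.0, 2.0, 3.0])   # roughly [0.090, 0.245, 0.665]
    assert abs(sum(probs) - 1.0) < 1e-9         # probabilities always sum to 1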
def compute_predictions(all_examples, all_features, all_results, n_best_size,
                        max_answer_length, do_lower_case, output_prediction_file,
                        output_nbest_file, output_null_log_odds_file, verbose_logging,
...

@@ -204,132 +512,192 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
    return all_predictions


(In this hunk, the old trailing copies of get_final_text, _get_best_indexes and _compute_softmax are removed; identical definitions now live earlier in the file. The new compute_predictions_extended function takes their place.)

def compute_predictions_extended(all_examples, all_features, all_results, n_best_size,
                                 max_answer_length, output_prediction_file,
                                 output_nbest_file, output_null_log_odds_file, orig_data_file,
                                 start_n_top, end_n_top, version_2_with_negative,
                                 tokenizer, verbose_logging):
    """ XLNet write prediction logic (more complex than Bert's).
        Write final predictions to the json file and log-odds of null if needed.
        Requires utils_squad_evaluate.py
    """
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index",
         "start_log_prob", "end_log_prob"])

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])

    logger.info("Writing predictions to: %s", output_prediction_file)
    # logger.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive

        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]

            cur_null_score = result.cls_logits

            # if we could have irrelevant answers, get the min score of irrelevant
            score_null = min(score_null, cur_null_score)

            for i in range(start_n_top):
                for j in range(end_n_top):
                    start_log_prob = result.start_top_log_probs[i]
                    start_index = result.start_top_index[i]

                    j_index = i * end_n_top + j

                    end_log_prob = result.end_top_log_probs[j_index]
                    end_index = result.end_top_index[j_index]

                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= feature.paragraph_len - 1:
                        continue
                    if end_index >= feature.paragraph_len - 1:
                        continue

                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue

                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_log_prob=start_log_prob,
                            end_log_prob=end_log_prob))

        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_log_prob + x.end_log_prob),
            reverse=True)

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]

            # XLNet un-tokenizer
            # Let's keep it simple for now and see if we need all this later.
            #
            # tok_start_to_orig_index = feature.tok_start_to_orig_index
            # tok_end_to_orig_index = feature.tok_end_to_orig_index
            # start_orig_pos = tok_start_to_orig_index[pred.start_index]
            # end_orig_pos = tok_end_to_orig_index[pred.end_index]
            # paragraph_text = example.paragraph_text
            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()

            # Previously used Bert untokenizer
            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

            # Clean whitespace
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)

            final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
                                        verbose_logging)

            if final_text in seen_predictions:
                continue

            seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_log_prob=pred.start_log_prob,
                    end_log_prob=pred.end_log_prob))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_log_prob + entry.end_log_prob)
            if not best_non_null_entry:
                best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_log_prob"] = entry.start_log_prob
            output["end_log_prob"] = entry.end_log_prob
            nbest_json.append(output)

        assert len(nbest_json) >= 1
        assert best_non_null_entry is not None

        score_diff = score_null
        scores_diff_json[example.qas_id] = score_diff
        # note(zhiliny): always predict best_non_null_entry
        # and the evaluation script will search for the best threshold
        all_predictions[example.qas_id] = best_non_null_entry.text

        all_nbest_json[example.qas_id] = nbest_json

    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with open(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if version_2_with_negative:
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    with open(orig_data_file, "r", encoding='utf-8') as reader:
        orig_data = json.load(reader)["data"]

    qid_to_has_ans = make_qid_to_has_ans(orig_data)
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
    out_eval = {}

    find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw,
                            scores_diff_json, qid_to_has_ans)

    return out_eval
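For orientation, an illustrative sketch (not part of the commit) of how run_squad.py invokes this XLNet/XLM path, as shown in the first file; the numeric values and file names are placeholders, and `examples`, `features`, `all_results` and `tokenizer` are assumed to come from the surrounding evaluation code.

    out_eval = compute_predictions_extended(examples, features, all_results,
                                            20, 30,                               # n_best_size, max_answer_length (placeholders)
                                            "predictions.json", "nbest.json",     # output files (placeholders)
                                            "null_odds.json",
                                            "dev-v2.0.json",                      # orig_data_file, i.e. args.predict_file
                                            5, 5,                                 # start_n_top, end_n_top (model.config values in run_squad.py)
                                            True,                                 # version_2_with_negative
                                            tokenizer, False)                     # verbose_logging
    # predictions.json maps each qas_id to the best non-null answer text; out_eval is filled by
    # find_all_best_thresh_v2, which is expected to mirror find_all_best_thresh above
    # ('best_exact', 'best_f1' and their thresholds).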
transformers/data/processors/squad.py  (+11 -8)

@@ -306,13 +306,13 @@ class SquadProcessor(DataProcessor):
                     else:
                         is_impossible = False

-                    if not is_impossible and is_training:
-                        if (len(qa["answers"]) != 1):
-                            raise ValueError("For training, each question should have exactly 1 answer.")
-                        answer = qa["answers"][0]
-                        answer_text = answer['text']
-                        start_position_character = answer['answer_start']
+                    if not is_impossible:
+                        if is_training:
+                            answer = qa["answers"][0]
+                            answer_text = answer['text']
+                            start_position_character = answer['answer_start']
+                        else:
+                            answers = qa["answers"]

                     example = SquadExample(
                         qas_id=qas_id,

@@ -321,7 +321,8 @@ class SquadProcessor(DataProcessor):
                         answer_text=answer_text,
                         start_position_character=start_position_character,
                         title=title,
-                        is_impossible=is_impossible)
+                        is_impossible=is_impossible,
+                        answers=answers)

                     examples.append(example)

@@ -352,6 +353,7 @@ class SquadExample(object):
                  answer_text,
                  start_position_character,
                  title,
+                 answers=None,
                  is_impossible=False):
         self.qas_id = qas_id
         self.question_text = question_text

@@ -359,6 +361,7 @@ class SquadExample(object):
         self.answer_text = answer_text
         self.title = title
         self.is_impossible = is_impossible
+        self.answers = answers
         self.start_position, self.end_position = 0, 0
...
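Illustrative only (a sketch, not part of the commit): with the new `answers` field, dev-set examples keep every gold answer from qa["answers"], which is exactly what get_raw_scores() and squad_evaluate() read back via example.answers. The printed values below are placeholders.

    processor = SquadV2Processor()
    dev_examples = processor.get_dev_examples("examples/squad")
    ex = dev_examples[0]
    print(ex.qas_id, ex.is_impossible)
    print(ex.answers)   # e.g. [{'text': 'Steve Smith', 'answer_start': 42}, ...], passed through from qa["answers"]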