Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
de276de1
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "a32ab5afbd286ef3640b5324ec7056304838804b"
Commit
de276de1
authored
Dec 03, 2019
by
LysandreJik
Browse files
Working evaluation
parent
c835bc85
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
505 additions
and
141 deletions
+505
-141
examples/run_squad.py
examples/run_squad.py
+18
-25
transformers/data/metrics/squad_metrics.py
transformers/data/metrics/squad_metrics.py
+476
-108
transformers/data/processors/squad.py
transformers/data/processors/squad.py
+11
-8
No files found.
examples/run_squad.py
View file @
de276de1
...
...
@@ -16,7 +16,8 @@
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from
__future__
import
absolute_import
,
division
,
print_function
from
transformers.data.processors.squad
import
SquadV1Processor
,
SquadV2Processor
from
transformers.data.processors.squad
import
SquadV1Processor
,
SquadV2Processor
,
SquadResult
from
transformers.data.metrics.squad_metrics
import
compute_predictions
,
compute_predictions_extended
,
squad_evaluate
import
argparse
import
logging
...
...
@@ -230,9 +231,11 @@ def evaluate(args, model, tokenizer, prefix=""):
model
.
eval
()
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
with
torch
.
no_grad
():
inputs
=
{
'input_ids'
:
batch
[
0
],
'attention_mask'
:
batch
[
1
]
}
inputs
=
{
'input_ids'
:
batch
[
0
],
'attention_mask'
:
batch
[
1
]
}
if
args
.
model_type
!=
'distilbert'
:
inputs
[
'token_type_ids'
]
=
None
if
args
.
model_type
==
'xlm'
else
batch
[
2
]
# XLM doesn't use segment_ids
example_indices
=
batch
[
3
]
...
...
@@ -244,18 +247,8 @@ def evaluate(args, model, tokenizer, prefix=""):
for
i
,
example_index
in
enumerate
(
example_indices
):
eval_feature
=
features
[
example_index
.
item
()]
unique_id
=
int
(
eval_feature
.
unique_id
)
if
args
.
model_type
in
[
'xlnet'
,
'xlm'
]:
# XLNet uses a more complex post-processing procedure
result
=
RawResultExtended
(
unique_id
=
unique_id
,
start_top_log_probs
=
to_list
(
outputs
[
0
][
i
]),
start_top_index
=
to_list
(
outputs
[
1
][
i
]),
end_top_log_probs
=
to_list
(
outputs
[
2
][
i
]),
end_top_index
=
to_list
(
outputs
[
3
][
i
]),
cls_logits
=
to_list
(
outputs
[
4
][
i
]))
else
:
result
=
RawResult
(
unique_id
=
unique_id
,
start_logits
=
to_list
(
outputs
[
0
][
i
]),
end_logits
=
to_list
(
outputs
[
1
][
i
]))
result
=
SquadResult
([
to_list
(
output
[
i
])
for
output
in
outputs
]
+
[
unique_id
])
all_results
.
append
(
result
)
evalTime
=
timeit
.
default_timer
()
-
start_time
...
...
@@ -271,22 +264,18 @@ def evaluate(args, model, tokenizer, prefix=""):
if
args
.
model_type
in
[
'xlnet'
,
'xlm'
]:
# XLNet uses a more complex post-processing procedure
write_predictions_extended
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
predictions
=
compute_predictions_extended
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
args
.
max_answer_length
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
args
.
predict_file
,
model
.
config
.
start_n_top
,
model
.
config
.
end_n_top
,
args
.
version_2_with_negative
,
tokenizer
,
args
.
verbose_logging
)
else
:
write_predictions
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
predictions
=
compute_predictions
(
examples
,
features
,
all_results
,
args
.
n_best_size
,
args
.
max_answer_length
,
args
.
do_lower_case
,
output_prediction_file
,
output_nbest_file
,
output_null_log_odds_file
,
args
.
verbose_logging
,
args
.
version_2_with_negative
,
args
.
null_score_diff_threshold
)
# Evaluate with the official SQuAD script
evaluate_options
=
EVAL_OPTS
(
data_file
=
args
.
predict_file
,
pred_file
=
output_prediction_file
,
na_prob_file
=
output_null_log_odds_file
)
results
=
evaluate_on_squad
(
evaluate_options
)
results
=
squad_evaluate
(
examples
,
predictions
)
return
results
def
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
):
...
...
@@ -306,8 +295,12 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
logger
.
info
(
"Creating features from dataset file at %s"
,
input_file
)
processor
=
SquadV2Processor
()
examples
=
processor
.
get_dev_examples
(
"examples/squad"
)
if
evaluate
else
processor
.
get_train_examples
(
"examples/squad"
)
features
=
squad_convert_examples_to_features
(
examples
=
processor
.
get_dev_examples
(
"examples/squad"
,
only_first
=
100
)
if
evaluate
else
processor
.
get_train_examples
(
"examples/squad"
)
# import tensorflow_datasets as tfds
# tfds_examples = tfds.load("squad")
# examples = SquadV1Processor().get_examples_from_dataset(tfds_examples["validation"])
features
=
squad_convert_examples_to_features
(
examples
=
examples
,
tokenizer
=
tokenizer
,
max_seq_length
=
args
.
max_seq_length
,
...
...
transformers/data/metrics/squad_metrics.py
View file @
de276de1
This diff is collapsed.
Click to expand it.
transformers/data/processors/squad.py
View file @
de276de1
...
...
@@ -306,13 +306,13 @@ class SquadProcessor(DataProcessor):
else
:
is_impossible
=
False
if
not
is_impossible
and
is_training
:
if
(
len
(
qa
[
"answers"
])
!=
1
)
:
raise
ValueError
(
"For training, each question should have exactly 1 answer."
)
answer
=
qa
[
"answers"
][
0
]
answer_text
=
answer
[
'text'
]
start_position_character
=
answer
[
'answer_start'
]
if
not
is_impossible
:
if
is_training
:
answer
=
qa
[
"answers"
][
0
]
answer_text
=
answer
[
'text'
]
start_position_character
=
answer
[
'answer_start'
]
else
:
answers
=
qa
[
"answers"
]
example
=
SquadExample
(
qas_id
=
qas_id
,
...
...
@@ -321,7 +321,8 @@ class SquadProcessor(DataProcessor):
answer_text
=
answer_text
,
start_position_character
=
start_position_character
,
title
=
title
,
is_impossible
=
is_impossible
is_impossible
=
is_impossible
,
answers
=
answers
)
examples
.
append
(
example
)
...
...
@@ -352,6 +353,7 @@ class SquadExample(object):
answer_text
,
start_position_character
,
title
,
answers
=
None
,
is_impossible
=
False
):
self
.
qas_id
=
qas_id
self
.
question_text
=
question_text
...
...
@@ -359,6 +361,7 @@ class SquadExample(object):
self
.
answer_text
=
answer_text
self
.
title
=
title
self
.
is_impossible
=
is_impossible
self
.
answers
=
answers
self
.
start_position
,
self
.
end_position
=
0
,
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment