chenpangpang / transformers · Commits

Commit e363e1d9 (unverified)
Authored Jun 07, 2021 by Russell Klopfer; committed by GitHub, Jun 07, 2021
Parent: 8994c1e4

adds metric prefix. (#12057)

* adds metric prefix.
* update tests to include prefix
Changes: 2 changed files, with 15 additions and 5 deletions (+15 −5).

    examples/pytorch/question-answering/trainer_qa.py   +12 −2
    examples/pytorch/test_examples.py                    +3 −3
examples/pytorch/question-answering/trainer_qa.py

@@ -31,7 +31,7 @@ class QuestionAnsweringTrainer(Trainer):
         self.eval_examples = eval_examples
         self.post_process_function = post_process_function
 
-    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
+    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
         eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
         eval_examples = self.eval_examples if eval_examples is None else eval_examples
@@ -56,6 +56,11 @@ class QuestionAnsweringTrainer(Trainer):
             eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
             metrics = self.compute_metrics(eval_preds)
 
+            # Prefix all keys with metric_key_prefix + '_'
+            for key in list(metrics.keys()):
+                if not key.startswith(f"{metric_key_prefix}_"):
+                    metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
             self.log(metrics)
         else:
             metrics = {}
@@ -67,7 +72,7 @@ class QuestionAnsweringTrainer(Trainer):
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
         return metrics
 
-    def predict(self, predict_dataset, predict_examples, ignore_keys=None):
+    def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
         predict_dataloader = self.get_test_dataloader(predict_dataset)
 
         # Temporarily disable metric computation, we will do it in the loop here.
@@ -92,4 +97,9 @@ class QuestionAnsweringTrainer(Trainer):
         predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
         metrics = self.compute_metrics(predictions)
 
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
         return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
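
To make the new behavior concrete, here is a standalone sketch of the key-prefixing idiom this commit adds to both evaluate() and predict(); the metrics dict below is invented for illustration and does not come from a real run:

# Standalone sketch of the prefixing loop added above; the metric values are made up.
metrics = {"exact": 71.4, "f1": 74.9, "eval_samples": 100}
metric_key_prefix = "eval"

# Iterate over a snapshot of the keys: the loop body pops and inserts entries,
# and mutating a dict while iterating over it directly raises RuntimeError.
for key in list(metrics.keys()):
    if not key.startswith(f"{metric_key_prefix}_"):
        metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

print(metrics)  # {'eval_samples': 100, 'eval_exact': 71.4, 'eval_f1': 74.9}

The startswith() guard keeps keys that are already namespaced (such as eval_samples) from being double-prefixed to eval_eval_samples.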
examples/pytorch/test_examples.py

@@ -213,7 +213,7 @@ class ExamplesTests(TestCasePlus):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            run_squad.py
+            run_qa.py
             --model_name_or_path bert-base-uncased
             --version_2_with_negative
             --train_file tests/fixtures/tests_samples/SQUAD/sample.json
@@ -232,8 +232,8 @@ class ExamplesTests(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             run_squad.main()
             result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["f1"], 30)
-            self.assertGreaterEqual(result["exact"], 30)
+            self.assertGreaterEqual(result["eval_f1"], 30)
+            self.assertGreaterEqual(result["eval_exact"], 30)
 
     def test_run_swag(self):
         stream_handler = logging.StreamHandler(sys.stdout)
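
The test changes follow directly from the new defaults: evaluate() now namespaces its metrics with "eval", so get_results() sees eval_f1 and eval_exact rather than bare f1 and exact, while predict() defaults to "test". A small self-contained sketch, using a hypothetical add_prefix() helper (not part of transformers) that mirrors the loop in the diff:

# Hypothetical helper mirroring the diff's prefixing loop, for illustration only.
def add_prefix(metrics: dict, metric_key_prefix: str) -> dict:
    return {
        (k if k.startswith(f"{metric_key_prefix}_") else f"{metric_key_prefix}_{k}"): v
        for k, v in metrics.items()
    }

raw = {"f1": 32.5, "exact": 31.0}
print(add_prefix(raw, "eval"))  # evaluate() default -> {'eval_f1': 32.5, 'eval_exact': 31.0}
print(add_prefix(raw, "test"))  # predict() default  -> {'test_f1': 32.5, 'test_exact': 31.0}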