gaoqiong / lm-evaluation-harness · Commits

Commit cbc5c9c8
Authored Mar 28, 2021 by Leo Gao
Parent: 14dd29c4

squad: fix aggregation

Showing 1 changed file with 36 additions and 26 deletions.

lm_eval/tasks/squad.py (+36, -26)
@@ -3,6 +3,19 @@ from math import exp
 from lm_eval.base import rf
 from lm_eval.metrics import f1_score, mean
 from .common import HFTask
+from functools import partial
+
+
+def _squad_metric(predictions, references):
+    squad_metric = datasets.load_metric("squad_v2")
+    return squad_metric.compute(predictions=predictions, references=references)
+
+
+def _squad_agg(key, items):
+    predictions, references = zip(*items)
+    return _squad_metric(predictions=predictions, references=references)[key]
+
 
 class SQuAD(HFTask):
     DATASET_PATH = "squad_v2"

@@ -63,34 +76,31 @@ class SQuAD(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        squad_metric = datasets.load_metric("squad_v2")
-
-        continuation, (logprob_unanswerable, _) = results
+        continuation, is_unanswerable = results
+        logprob_unanswerable, is_greedy = is_unanswerable
+
         no_answer_probability = exp(logprob_unanswerable)
 
-        predictions = [{
+        predictions = {
             'id': doc['id'],
             'prediction_text': continuation,
             'no_answer_probability': no_answer_probability,
-        }]
+        }
 
-        references = [{
+        references = {
             'id': doc['id'],
             'answers': doc['answers'],
-        }]
+        }
 
-        metrics = squad_metric.compute(predictions=predictions, references=references)
-        metrics.pop('total', None)
-        metrics.pop('HasAns_total', None)
-        metrics.pop('NoAns_total', None)
-        metrics.pop('best_exact_thresh', None)
-        metrics.pop('best_f1_thresh', None)
-        return metrics
+        return {
+            'exact': (predictions, references),  # Exact match (the normalized answer exactly match the gold answer)
+            'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': (predictions, references),  # Exact match (the normalized answer exactly match the gold answer)
+            'HasAns_f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            'NoAns_exact': (predictions, references),  # Exact match (the normalized answer exactly match the gold answer)
+            'NoAns_f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            'best_exact': (predictions, references),  # Best exact match (with varying threshold)
+            'best_f1': (predictions, references),  # Best F1 (with varying threshold)
+        }
 
     def aggregation(self):
         """

@@ -99,14 +109,14 @@ class SQuAD(HFTask):
             functions that aggregate a list of metrics
         """
         return {
-            'exact': mean, # Exact match (the normalized answer exactly match the gold answer)
-            'f1': mean, # The F-score of predicted tokens versus the gold answer
-            'HasAns_exact': mean, # Exact match (the normalized answer exactly match the gold answer)
-            'HasAns_f1': mean, # The F-score of predicted tokens versus the gold answer
-            'NoAns_exact': mean, # Exact match (the normalized answer exactly match the gold answer)
-            'NoAns_f1': mean, # The F-score of predicted tokens versus the gold answer
-            'best_exact': mean, # Best exact match (with varying threshold)
-            'best_f1': mean, # Best F1 (with varying threshold)
+            'exact': partial(_squad_agg, 'exact'), # Exact match (the normalized answer exactly match the gold answer)
+            'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': partial(_squad_agg, 'HasAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
+            'HasAns_f1': partial(_squad_agg, 'HasAns_f1'), # The F-score of predicted tokens versus the gold answer
+            'NoAns_exact': partial(_squad_agg, 'NoAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
+            'NoAns_f1': partial(_squad_agg, 'NoAns_f1'), # The F-score of predicted tokens versus the gold answer
+            'best_exact': partial(_squad_agg, 'best_exact'), # Best exact match (with varying threshold)
+            'best_f1': partial(_squad_agg, 'best_f1'), # Best F1 (with varying threshold)
         }
 
     def higher_is_better(self):
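Before this commit, process_results ran the squad_v2 metric separately on every document and aggregation() averaged those per-document numbers with mean; scores such as best_exact and best_f1 are defined over the full prediction set, so a per-document average does not match the official SQuAD v2 score. After the change, process_results only returns the (predictions, references) pair for each document, and each metric key is aggregated by partial(_squad_agg, key), which zips the collected pairs and runs the squad_v2 metric once over the whole evaluation set. The sketch below is not part of the commit: it wires the two new helpers to two hand-written, hypothetical documents to show what the aggregation functions receive and return.

# Minimal sketch (not from the commit) of the new aggregation path.
# The ids, texts and log-probabilities below are hypothetical; in the
# harness they come from the evaluation loop and process_results.
from functools import partial
from math import exp

import datasets


def _squad_metric(predictions, references):
    # Same helper as in the commit: one squad_v2 metric computation.
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)


def _squad_agg(key, items):
    # items: one (prediction, reference) pair per document, as returned by
    # process_results; score the whole corpus in a single metric call.
    predictions, references = zip(*items)
    return _squad_metric(predictions=predictions, references=references)[key]


# Two hypothetical per-document results: one answerable, one unanswerable.
items = [
    (
        {'id': '1', 'prediction_text': 'Denver Broncos', 'no_answer_probability': exp(-9.2)},
        {'id': '1', 'answers': {'text': ['Denver Broncos'], 'answer_start': [177]}},
    ),
    (
        {'id': '2', 'prediction_text': '', 'no_answer_probability': exp(-0.1)},
        {'id': '2', 'answers': {'text': [], 'answer_start': []}},
    ),
]

# Same shape as the dict returned by the new aggregation() method.
aggregation = {
    'exact': partial(_squad_agg, 'exact'),
    'f1': partial(_squad_agg, 'f1'),
}
print(aggregation['exact'](items))  # corpus-level squad_v2 exact-match score

Because the corpus-level computation happens in the aggregator, the thresholds and *_total bookkeeping keys that the old code had to pop per document are simply ignored: _squad_agg indexes the metric dict by the one key it was asked for.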