gaoqiong / lm-evaluation-harness / Commits / 45fefe9f

Commit 45fefe9f, authored May 02, 2021 by Leo Gao
implement stderr calculation
Parent: 8846bec0
Showing 2 changed files with 72 additions and 0 deletions:

  lm_eval/evaluator.py  +5  -0
  lm_eval/metrics.py    +67 -0
lm_eval/evaluator.py @ 45fefe9f

 import collections
 import itertools
 import random
+import lm_eval.metrics


 def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
...
@@ -89,4 +90,8 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
         task = task_dict[task_name]
         results[task_name][metric] = task.aggregation()[metric](items)
+        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric])
+
+        if stderr is not None:
+            results[task_name][metric + "_stderr"] = stderr(items)

     return results
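The evaluator change only consults lm_eval.metrics.stderr_for_metric and, when an estimator exists for the task's aggregation function, stores an extra "<metric>_stderr" entry. A minimal sketch of that behaviour, assuming lm_eval is importable and using mean as a stand-in aggregation; the task name, metric key, and items are hypothetical:

    from lm_eval.metrics import mean, stderr_for_metric

    items = [1, 0, 1, 1, 0, 1]         # hypothetical per-example 0/1 scores
    results = {"my_task": {}}          # hypothetical results dict

    results["my_task"]["acc"] = mean(items)

    stderr = stderr_for_metric(mean)   # resolves to mean_stderr for the mean aggregation
    if stderr is not None:
        results["my_task"]["acc_stderr"] = stderr(items)

    print(results)
    # roughly {'my_task': {'acc': 0.667, 'acc_stderr': 0.192}}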
lm_eval/metrics.py @ 45fefe9f

@@ -5,12 +5,23 @@ from pprint import pprint
 import numpy as np
 import sacrebleu
 import sklearn
+import random


 def mean(arr):
     return sum(arr) / len(arr)


+def stddev(arr):
+    mu = mean(arr)
+    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
+
+
+def mean_stderr(arr):
+    print(stddev(arr), len(arr))
+    return stddev(arr) / math.sqrt(len(arr))
+
+
 def median(arr):
     return arr[len(arr) // 2]
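The new stddev/mean_stderr helpers are the usual closed-form standard error of the mean: population standard deviation divided by sqrt(n). A standalone check under that reading (the helpers mirror metrics.py but are redefined here so the snippet runs on its own; the committed debug print in mean_stderr is omitted, and the scores list is made up):

    import math

    def mean(arr):
        return sum(arr) / len(arr)

    def stddev(arr):
        mu = mean(arr)
        return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))

    def mean_stderr(arr):
        return stddev(arr) / math.sqrt(len(arr))

    scores = [1, 1, 0, 1, 0, 1, 1, 0]  # hypothetical 0/1 accuracy scores
    print(mean(scores))                # 0.625
    print(mean_stderr(scores))         # population stddev / sqrt(n), about 0.171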
@@ -48,6 +59,23 @@ def acc_all(items):
     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
     return acc


+def acc_all_stderr(items):
+    # Only count as correct if all answers are labeled correctly for each question
+    question_scoring_dict = {}
+    preds = list(zip(*items))[0]
+    docs = list(zip(*items))[1]
+
+    for doc, pred in zip(docs, preds):
+        question_id = doc["idx"]["question"]
+        if question_id not in question_scoring_dict:
+            question_scoring_dict[question_id] = []
+
+        gold_label = doc["label"] == 1
+        question_scoring_dict[question_id].append(gold_label == pred)
+    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
+    return acc
+
+
 def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     """Compute max metric between prediction and each ground truth."""
@@ -138,3 +166,42 @@ def _sacreformat(refs, preds):
         preds = [pred[0] for pred in preds]

     return refs, preds
+
+
+## stderr stuff
+
+def bootstrap_stddev(f, xs, iters=10000):
+    rnd = random.Random()
+    rnd.seed(42)
+    res = []
+
+    from tqdm import trange
+    print("bootstrapping for stddev:", f.__name__)
+    for i in trange(iters):
+        # sample w replacement
+        bootstrap = rnd.choices(xs, k=len(xs))
+        res.append(stddev(bootstrap))
+
+    return mean(res)
+
+
+def stderr_for_metric(metric):
+    bootstrappable = [
+        median,
+        matthews_corrcoef,
+        f1_score,
+        perplexity,
+        bleu,
+        chrf,
+        ter,
+    ]
+
+    if metric in bootstrappable:
+        return lambda x: bootstrap_stddev(metric, x) / math.sqrt(len(x))
+
+    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
+
+    return stderr.get(metric, None)
\ No newline at end of file
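For metrics without a closed-form estimator, stderr_for_metric falls back to resampling: it returns a lambda that calls bootstrap_stddev on the items and divides by sqrt(n). A usage sketch, assuming lm_eval and tqdm are importable; the scores are synthetic, and median is used because it sits in the bootstrappable list:

    import random

    from lm_eval.metrics import median, stderr_for_metric

    rng = random.Random(0)
    items = sorted(rng.gauss(0.0, 1.0) for _ in range(200))  # synthetic per-example scores

    stderr = stderr_for_metric(median)  # bootstrapped, since median is bootstrappable
    if stderr is not None:
        print("median:", median(items))
        print("median_stderr:", stderr(items))  # 10000 resamples, with a tqdm progress bar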