Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
4acb339e
Commit
4acb339e
authored
Dec 06, 2023
by
lintangsutawika
Browse files
fixed brier score to accomodate samples with different number of choices
parent
835cc40e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
8 deletions
+21
-8
lm_eval/api/metrics.py
lm_eval/api/metrics.py
+21
-8
No files found.
lm_eval/api/metrics.py
View file @
4acb339e
import
math
import
math
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
collections
import
defaultdict
import
numpy
as
np
import
numpy
as
np
import
sacrebleu
import
sacrebleu
import
sklearn.metrics
import
sklearn.metrics
...
@@ -111,13 +111,26 @@ def ter(items):
...
@@ -111,13 +111,26 @@ def ter(items):
@
register_aggregation
(
"brier_score"
)
@
register_aggregation
(
"brier_score"
)
def
brier_score
(
items
):
# This is a passthrough function
def
brier_score
(
items
):
# This is a passthrough function
gold
,
predictions
=
list
(
zip
(
*
items
))
print
(
type
(
predictions
))
# Certain datasets like arc_easy can have a different number of choices.
predictions
=
np
.
array
(
predictions
)
golds
,
predictions
=
list
(
zip
(
*
items
))
print
(
predictions
.
shape
)
gold
=
np
.
array
(
gold
)
pred_group
=
defaultdict
(
list
)
gold_one_hot
=
np
.
eye
(
len
(
predictions
[
0
]))[
gold
]
gold_group
=
defaultdict
(
list
)
return
np
.
mean
(
np
.
sum
((
predictions
-
gold_one_hot
)
**
2
,
axis
=
1
))
for
gold
,
pred
in
zip
(
golds
,
predictions
):
pred_group
[
len
(
pred
)].
append
(
pred
)
gold_group
[
len
(
pred
)].
append
(
gold
)
total_size
=
0
average
=
0
for
g
,
p
in
zip
(
gold_group
.
values
(),
pred_group
.
values
()):
_p
=
np
.
array
(
p
)
_g
=
np
.
array
(
g
)
_g_one_hot
=
np
.
eye
(
len
(
_p
[
0
]))[
_g
]
average
+=
np
.
mean
(
np
.
sum
((
_p
-
_g_one_hot
)
**
2
,
axis
=
1
))
*
len
(
_g
)
total_size
+=
len
(
_g
)
return
average
/
total_size
@
register_metric
(
@
register_metric
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment