gaoqiong / lm-evaluation-harness · Commits

Commit c6a91582
Authored Dec 27, 2023 by lintangsutawika
Parent: a808c661

Commit message: update

Showing 1 changed file with 12 additions and 22 deletions.

lm_eval/api/metrics.py  (+12, -22)
 import math
 from collections.abc import Iterable
+import abc
 import numpy as np
 import sacrebleu
 import sklearn.metrics
@@ -17,33 +18,21 @@ eval_logger = logging.getLogger("lm-eval")
 class BaseMetric:
     def __init__(
         self,
-        aggregation=None,
     ) -> None:
-        self.aggregation = aggregation
-
-    def __call__(self, *items):
-        sample_wise_score = self.sample_wise_compute(*items)
-
-        if self.aggregation is not None:
-            return self.aggregation(sample_wise_score)
-        else:
-            return self.set_wise_compute(sample_wise_score)
-
-    def sample_wise_compute(self, *items):
-        return items
-
-    def set_wise_compute(self, *items):
-        return items
+    @abc.abstractmethod
+    def update(self, *items):
+        pass
+
+    @abc.abstractmethod
+    def compute(self, *items):
+        pass


-# Register Aggregations First
-@register_aggregation("mean")
 def mean(arr):
     return sum(arr) / len(arr)


-@register_aggregation("median")
 def median(arr):
     return arr[len(arr) // 2]
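The refactor replaces the old aggregation callback with two hooks: update() maps one sample's model outputs to a per-sample value, and compute() reduces the collected values into the final score. A minimal standalone sketch of how a subclass and a driver loop could look under that contract; the MeanAccuracy class, the outputs list, and the driver loop below are illustrative assumptions, not code from the commit:

# Standalone sketch (not from the commit) of the update()/compute() contract.
import abc


class BaseMetric:
    @abc.abstractmethod
    def update(self, *items):
        pass

    @abc.abstractmethod
    def compute(self, *items):
        pass


class MeanAccuracy(BaseMetric):
    def update(self, ll, is_greedy):
        # Per-sample value: 1 if the target was the greedy continuation.
        return int(is_greedy)

    def compute(self, items):
        # Corpus-level score: mean of the collected per-sample values.
        return sum(items) / len(items)


# Hypothetical driver: call update() once per sample, compute() once at the end.
metric = MeanAccuracy()
outputs = [(-1.2, True), (-3.4, False), (-0.7, True)]  # (loglikelihood, is_greedy)
per_sample = [metric.update(ll, greedy) for ll, greedy in outputs]
print(metric.compute(per_sample))  # 0.666...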
@@ -54,10 +43,10 @@ def median(arr):
     output_type="loglikelihood",
 )
 class PerplexityMetric(BaseMetric):
-    def sample_wise_compute(self, ll, is_greedy):
+    def update(self, ll, is_greedy):
         return ll

-    def set_wise_compute(self, items):
+    def compute(self, items):
         return math.exp(-mean(items))
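In PerplexityMetric, update() now just passes each sample's log-likelihood through, and compute() exponentiates the negative mean, i.e. perplexity = exp(-mean(ll)). A quick numeric check of that formula, with invented log-likelihood values:

# Worked check of perplexity = exp(-mean(ll)); the values are made up.
import math

lls = [-2.3, -1.6, -0.9]             # per-sample log-likelihoods, as returned by update()
ppl = math.exp(-sum(lls) / len(lls))
print(round(ppl, 3))                 # 4.953, i.e. exp(1.6)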
@@ -65,12 +54,13 @@ class PerplexityMetric(BaseMetric):
     metric="acc",
     higher_is_better=True,
     output_type="loglikelihood",
-    aggregation="mean",
 )
 class LoglikelihoodAccMetric(BaseMetric):
-    def __call__(self, ll, is_greedy):
+    def update(self, ll, is_greedy):
         return int(is_greedy)

+    def compute(self, items):
+        return math.exp(-mean(items))


 @register_aggregation("f1")
 def f1_score(items):
...
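The f1 aggregation appears above only as truncated context; its body is not part of this diff. Given the sklearn.metrics import at the top of the file and the usual convention that items is a list of (gold, pred) pairs, a sketch along these lines would fit, but it is an assumption rather than the file's actual implementation:

# Sketch only: the real f1_score body is elided in the diff above.
# Assumes items is an iterable of (gold, pred) label pairs.
import sklearn.metrics


def f1_score(items):
    golds, preds = zip(*items)
    return sklearn.metrics.f1_score(golds, preds)


print(f1_score([(1, 1), (0, 1), (1, 1), (0, 0)]))  # 0.8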