gaoqiong / lm-evaluation-harness / Commits / aaf64aab

Commit aaf64aab, authored Jan 02, 2024 by lintangsutawika

    Re-added support for aggregation

parent 439dca55

Showing 3 changed files with 95 additions and 47 deletions (+95 / -47):

    lm_eval/api/metrics.py     +5   -1
    lm_eval/api/registry.py    +48  -31
    lm_eval/api/task.py        +42  -15
lm_eval/api/metrics.py  (view file @ aaf64aab)
@@ -8,20 +8,23 @@ import numpy as np
 import sacrebleu
 import sklearn.metrics
 
-from lm_eval.api.registry import register_metric
+from lm_eval.api.registry import register_metric, register_aggregation
 
 eval_logger = logging.getLogger("lm-eval")
 
 
+@register_aggregation("mean")
 def mean(arr):
     return sum(arr) / len(arr)
 
 
+@register_aggregation("median")
 def median(arr):
     return arr[len(arr) // 2]
 
 
+@register_aggregation("weighted_mean")
 def weighted_mean(items):
     a, b = zip(*items)
     return sum(a) / sum(b)

@@ -161,6 +164,7 @@ def acc_mutual_info_fn(items):
 exact_match = evaluate.load("exact_match")
 
 
 @register_metric(
     metric="exact_match",
     higher_is_better=True,
...
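With the @register_aggregation decorator re-added, each aggregation function becomes reachable by a plain string name, which is what task configurations pass around. A minimal, self-contained sketch of the pattern (an illustrative re-implementation, not the lm-eval code itself; the real registry lives in lm_eval/api/registry.py below):

# Illustrative sketch of the registration pattern; names mirror the diff.
AGGREGATION_REGISTRY = {}

def register_aggregation(name):
    def decorate(fn):
        assert name not in AGGREGATION_REGISTRY, (
            f"aggregation named '{name}' conflicts with an existing registration!"
        )
        AGGREGATION_REGISTRY[name] = fn
        return fn
    return decorate

@register_aggregation("weighted_mean")
def weighted_mean(items):
    # items are (value, weight) pairs, e.g. (loglikelihood, token_count)
    a, b = zip(*items)
    return sum(a) / sum(b)

# A task config can now refer to the aggregation by its string name:
agg_fn = AGGREGATION_REGISTRY["weighted_mean"]
print(agg_fn([(2.0, 1), (6.0, 3)]))  # (2.0 + 6.0) / (1 + 3) = 2.0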
lm_eval/api/registry.py  (view file @ aaf64aab)
 import os
 import logging
 import evaluate
+import collections
 
 from functools import partial
 
 from lm_eval.api.model import LM

@@ -9,21 +10,6 @@ eval_logger = logging.getLogger("lm-eval")
 MODEL_REGISTRY = {}
 
 
-class HFEvaluateAdaptor:
-    def __init__(self, name, **kwargs):
-        self.name = name
-        metric_object = evaluate.load(name)
-        self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
-
-    def __call__(self, items):
-        refs = list(zip(*items))[0]
-        preds = list(zip(*items))[1]
-        return self.hf_evaluate_fn(references=refs, predictions=preds)[self.name]
-
-
 def register_model(*names):
     # either pass a list or a single alias.

@@ -87,8 +73,8 @@ def register_group(name):
     return decorate
 
 
-METRIC_FUNCTION_REGISTRY = {}
-HIGHER_IS_BETTER_REGISTRY = {}
+METRIC_REGISTRY = collections.defaultdict(dict)
+AGGREGATION_REGISTRY = collections.defaultdict(dict)
 DEFAULT_METRIC_REGISTRY = {
     "loglikelihood": [],

@@ -102,6 +88,7 @@ def register_metric(
     metric=None,
     higher_is_better=None,
     output_type=None,
+    aggregation=None,
 ):
     # TODO: do we want to enforce a certain interface to registered metrics?
     def decorate(fn):

@@ -112,10 +99,13 @@ def register_metric(
             metric_list = metric
 
         for _metric in metric_list:
-            METRIC_FUNCTION_REGISTRY[_metric] = fn
+            METRIC_REGISTRY[_metric]["function"] = fn
+
+            if aggregation is not None:
+                METRIC_REGISTRY[_metric]["aggregation"] = aggregation
 
             if higher_is_better is not None:
-                HIGHER_IS_BETTER_REGISTRY[_metric] = higher_is_better
+                METRIC_REGISTRY[_metric]["higher_is_better"] = higher_is_better
 
             if output_type is not None:
                 if type(output_type) == str:

@@ -131,18 +121,33 @@ def register_metric(
     return decorate
 
 
-def get_metric(name, hf_evaluate_metric=False, **kwargs):
-    if not hf_evaluate_metric:
-        if name in METRIC_FUNCTION_REGISTRY:
-            return METRIC_FUNCTION_REGISTRY[name]
-        else:
-            eval_logger.warning(
-                f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
-            )
+def get_metric(name):
+    if name in METRIC_REGISTRY:
+        return METRIC_REGISTRY[name]
+    else:
+        eval_logger.error(f"Could not find registered metric '{name}' in lm-eval")
+
+
+def get_evaluate(name, **kwargs):
     try:
+        # from lm_eval.metrics import HFEvaluateAdaptor
+        class HFEvaluateAdaptor:
+            def __init__(self, name, **kwargs):
+                self.name = name
+                metric_object = evaluate.load(name)
+                self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)
+
+            def __call__(self, items):
+                refs = list(zip(*items))[0]
+                preds = list(zip(*items))[1]
+                return self.hf_evaluate_fn(references=refs, predictions=preds)[self.name]
+
         return HFEvaluateAdaptor(name, **kwargs)
     except Exception:
         eval_logger.error(
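The HF Evaluate fallback that get_metric used to perform implicitly now lives in its own get_evaluate() helper, with the adaptor class defined inline. What the adaptor does with the scored items is an unzip followed by a compute() call; a small usage sketch, assuming the evaluate package is installed and using the "exact_match" metric that metrics.py already loads:

import evaluate

# Each item is a (reference, prediction) pair collected during evaluation.
items = [("paris", "paris"), ("london", "berlin")]

refs = list(zip(*items))[0]   # ("paris", "london")
preds = list(zip(*items))[1]  # ("paris", "berlin")

exact_match = evaluate.load("exact_match")
# compute() returns a dict keyed by the metric name, here {"exact_match": 0.5}
print(exact_match.compute(references=refs, predictions=preds)["exact_match"])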
@@ -150,10 +155,22 @@ def get_metric(name, hf_evaluate_metric=False, **kwargs):
         )
 
 
-def is_higher_better(metric_name):
+def register_aggregation(name):
+    def decorate(fn):
+        assert (
+            name not in AGGREGATION_REGISTRY
+        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
+
+        AGGREGATION_REGISTRY[name] = fn
+        return fn
+
+    return decorate
+
+
+def get_aggregation(name):
     try:
-        return HIGHER_IS_BETTER_REGISTRY[metric_name]
+        return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(
-            f"higher_is_better not specified for metric '{metric_name}'!"
+            "{} not a registered aggregation metric!".format(name),
         )
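Taken together, register_metric() now records everything about a metric in a single METRIC_REGISTRY entry, so get_metric(name) returns a dict with "function", optionally "aggregation", and "higher_is_better" keys instead of a bare callable. A condensed sketch of that flow (simplified to a single metric name; "my_accuracy" is a hypothetical example, not a metric shipped with lm-eval):

import collections

METRIC_REGISTRY = collections.defaultdict(dict)

def register_metric(metric=None, higher_is_better=None, aggregation=None):
    # Simplified: the real decorator also accepts a list of names and output_type.
    def decorate(fn):
        METRIC_REGISTRY[metric]["function"] = fn
        if aggregation is not None:
            METRIC_REGISTRY[metric]["aggregation"] = aggregation
        if higher_is_better is not None:
            METRIC_REGISTRY[metric]["higher_is_better"] = higher_is_better
        return fn
    return decorate

@register_metric(metric="my_accuracy", higher_is_better=True, aggregation="mean")
def my_accuracy(items):
    golds, preds = zip(*items)
    return sum(g == p for g, p in zip(golds, preds)) / len(items)

entry = METRIC_REGISTRY["my_accuracy"]
print(sorted(entry))                                    # ['aggregation', 'function', 'higher_is_better']
print(entry["aggregation"], entry["higher_is_better"])  # mean True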
lm_eval/api/task.py  (view file @ aaf64aab)
@@ -32,7 +32,9 @@ from lm_eval.api.metrics import (
 )
 from lm_eval.api.registry import (
     get_metric,
-    is_higher_better,
+    get_evaluate,
+    get_aggregation,
+    METRIC_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
 )

@@ -410,7 +412,7 @@ class Task(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def compute_metric(self):
+    def aggregation(self):
         """
         :returns: {str: [metric_score] -> float}
             A dictionary where keys are the names of submetrics and values are

@@ -553,6 +555,7 @@ class ConfigurableTask(Task):
         self._metric_fn_list = {}
         self._metric_fn_kwargs = {}
+        self._aggregation_list = {}
         self._higher_is_better = {}
 
         if self.config.metric_list is None:

@@ -561,12 +564,14 @@ class ConfigurableTask(Task):
             for metric_name in _metric_list:
                 metric = get_metric(metric_name)
-                self._metric_fn_list[metric_name] = metric
+                self._metric_fn_list[metric_name] = metric["function"]
                 self._metric_fn_kwargs[metric_name] = {}
-                self._higher_is_better[metric_name] = is_higher_better(metric_name)
+                self._aggregation_list = metric["aggregation"]
+                self._higher_is_better[metric_name] = metric["is_higher_better"]
         else:
             for metric_config in self.config.metric_list:
                 assert "metric" in metric_config
+                from_registry = False
                 metric_name = metric_config["metric"]
                 kwargs = {
                     key: metric_config[key]

@@ -574,25 +579,47 @@ class ConfigurableTask(Task):
                     if key
                     not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
                 }
-                hf_evaluate_metric = (
+                use_hf_evaluate = (
                     "hf_evaluate" in metric_config
                     and metric_config["hf_evaluate"] is True
                 )
 
+                # if self.config.process_results is not None:
+                #     self._metric_fn_list[metric_name] = None
+                #     self._metric_fn_kwargs[metric_name] = {}
+
                 if callable(metric_name):
                     metric_fn = metric_name.__call__
                     metric_name = metric_name.__name__
                 else:
-                    metric_fn = get_metric(
-                        metric_name, hf_evaluate_metric, **kwargs
-                    )
+                    assert type(metric_name) == str
+                    if use_hf_evaluate:
+                        metric_fn = get_evaluate(metric_name, **kwargs)
+                    elif metric_name in METRIC_REGISTRY:
+                        from_registry = True
+                        metric = get_metric(metric_name, **kwargs)
+                        metric_fn = metric["function"]
 
                 self._metric_fn_kwargs[metric_name] = kwargs
                 self._metric_fn_list[metric_name] = metric_fn
 
+                if "aggregation" in metric_config:
+                    agg_name = metric_config["aggregation"]
+                    if isinstance(agg_name, str):
+                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
+                    elif callable(agg_name):  # noqa: E721
+                        self._aggregation_list[metric_name] = agg_name
+                else:
+                    if from_registry:
+                        if "aggregation" in metric:
+                            self._aggregation_list[metric_name] = metric["aggregation"]
+                        else:
+                            self._aggregation_list[metric_name] = metric_fn
+
+                if "higher_is_better" in metric_config:
+                    self._higher_is_better[metric_name] = metric_config["higher_is_better"]
+                else:
+                    if from_registry:
+                        self._higher_is_better[metric_name] = metric["higher_is_better"]
+
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
         self._fewshot_docs = None
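The aggregation branch above encodes a precedence order that is easy to lose in the diff noise: an explicit config value wins (a string is looked up with get_aggregation, a callable is taken as-is), then the registered metric's own aggregation, then the metric function itself. A condensed restatement as a hypothetical helper (resolve_aggregation is not part of the commit):

from lm_eval.api.registry import get_aggregation

def resolve_aggregation(metric_config, metric_entry, metric_fn, from_registry):
    # Condensed mirror of the precedence in ConfigurableTask.__init__ (sketch only).
    if "aggregation" in metric_config:
        agg = metric_config["aggregation"]
        # a string names a registered aggregation; a callable is used directly
        return get_aggregation(agg) if isinstance(agg, str) else agg
    if from_registry:
        # fall back to whatever the metric registered, else to the metric fn itself
        return metric_entry.get("aggregation", metric_fn)
    return None  # nothing is recorded for unregistered metrics without a config value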
@@ -1157,8 +1184,8 @@ class ConfigurableTask(Task):
         return result_dict
 
-    def compute_metric(self):
-        return self._metric_fn_list
+    def aggregation(self):
+        return self._aggregation_list
 
     def higher_is_better(self):
         return self._higher_is_better

@@ -1204,7 +1231,7 @@ class MultipleChoiceTask(Task):
             "acc_norm": True,
         }
 
-    def compute_metric(self) -> dict:
+    def aggregation(self) -> dict:
         return {
             "acc": mean,
             "acc_norm": mean,
...

@@ -1265,7 +1292,7 @@ class PerplexityTask(Task):
             "bits_per_byte": (loglikelihood, bytes_),
         }
 
-    def compute_metric(self) -> dict:
+    def aggregation(self) -> dict:
         return {
             "word_perplexity": weighted_perplexity,
             "byte_perplexity": weighted_perplexity,
...
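On the consumer side, these changes are what let a task's metric_list drive both the metric function and its aggregation. A hypothetical configuration, written as Python dicts for illustration (task YAML files map to the same structure), covering the three paths the constructor now distinguishes:

metric_list = [
    # Registered lm-eval metric (from_registry path): the function comes from
    # METRIC_REGISTRY, and aggregation / higher_is_better fall back to
    # whatever was registered for it.
    {"metric": "exact_match"},
    # Explicit overrides: "mean" is resolved through get_aggregation(), and
    # higher_is_better is taken directly from the config entry.
    {"metric": "exact_match", "aggregation": "mean", "higher_is_better": True},
    # HF Evaluate fallback: hf_evaluate=True routes the name through
    # get_evaluate(), which wraps evaluate.load("bleu"); any extra keys in the
    # entry are forwarded to compute() as kwargs.
    {"metric": "bleu", "hf_evaluate": True},
]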