gaoqiong / lm-evaluation-harness · Commits

Commit 1b5c6f88
authored Jun 30, 2025 by Baber

    add MetricConfig

parent 6b3f3f7e

Showing 9 changed files with 215 additions and 154 deletions (+215 -154).
lm_eval/api/group.py            +1   -2
lm_eval/api/instance.py         +17  -4
lm_eval/api/metrics.py          +6   -1
lm_eval/api/registry.py         +13  -10
lm_eval/api/samplers.py         +3   -3
lm_eval/api/task.py             +170 -129
lm_eval/evaluator.py            +1   -1
lm_eval/evaluator_utils.py      +2   -2
lm_eval/filters/__init__.py     +2   -2
lm_eval/api/group.py  [view file @ 1b5c6f88]

-import abc
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import Any, Callable, List, Optional, Union
...
@@ -84,7 +83,7 @@ class GroupConfig(dict):
         return str(value)


-class ConfigurableGroup(abc.ABC):
+class ConfigurableGroup:
     def __init__(
         self,
         config: Optional[dict] = None,
...
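With the `abc.ABC` base removed, `ConfigurableGroup` can now be instantiated directly from a plain config dict. A minimal sketch, assuming `lm_eval` is importable; the `group` and `task` keys are assumptions based on the usual `GroupConfig` fields, which this hunk does not show:

from lm_eval.api.group import ConfigurableGroup

# Hypothetical config values, for illustration only.
group = ConfigurableGroup(config={"group": "my_group", "task": ["task_a", "task_b"]})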
lm_eval/api/instance.py  [view file @ 1b5c6f88]

...
@@ -14,10 +14,23 @@ class Instance:
     arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
-        default_factory=lambda: (None, None, None)
+        default_factory=lambda: (None, None, None),
+        metadata=dict(
+            description="Metadata tuple containing task name, document ID, and number of repeats."
+        ),
+    )
+    resps: list = field(
+        default_factory=list,
+        metadata=dict(description="List of responses from the model for this instance."),
+    )
+    filtered_resps: dict = field(
+        default_factory=dict,
+        metadata=dict(
+            description="List of filtered responses for this instance, keyed by filter name."
+        ),
     )
-    resps: list = field(default_factory=list)
-    filtered_resps: dict = field(default_factory=dict)

     # initialized after init
     task_name: Optional[str] = None
...
@@ -29,7 +42,7 @@ class Instance:
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self):
+    def args(self) -> tuple:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
...
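The new `field(metadata=...)` entries make the dataclass self-documenting. A self-contained sketch of how such descriptions can be read back with the standard `dataclasses` machinery (a toy class standing in for the real `Instance`):

from dataclasses import dataclass, field, fields

@dataclass
class Toy:
    # mirrors the pattern used in Instance above
    resps: list = field(
        default_factory=list,
        metadata=dict(description="List of responses from the model for this instance."),
    )

for f in fields(Toy):
    # prints: resps -> List of responses from the model for this instance.
    print(f.name, "->", f.metadata.get("description", ""))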
lm_eval/api/metrics.py  [view file @ 1b5c6f88]

...
@@ -7,7 +7,6 @@ from collections.abc import Iterable
 from typing import List

 import numpy as np
-import sacrebleu

 from lm_eval.api.registry import register_aggregation, register_metric
...
@@ -89,6 +88,8 @@ def bleu(items):
     Higher is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
...
@@ -104,6 +105,8 @@ def chrf(items):
     Higher is better  # TODO I think
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
...
@@ -120,6 +123,8 @@ def ter(items):
     Lower is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
...
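Moving `import sacrebleu` from module scope into the metric bodies is the standard lazy-import pattern: the dependency is only paid for (and an ImportError only raised) when a sacrebleu-backed metric actually runs. A minimal sketch of the same idea, assuming sacrebleu is installed; `items` follows the (ref, pred) pair convention used above:

def bleu_like_metric(items):
    # Deferred import: importing this module stays cheap; the cost is
    # incurred only when the metric is called.
    import sacrebleu

    refs, preds = zip(*items)
    # one reference stream, aligned with the predictions
    return sacrebleu.corpus_bleu(list(preds), [list(refs)]).score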
lm_eval/api/registry.py  [view file @ 1b5c6f88]

 import logging
-from typing import Callable, Dict, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

-import evaluate as hf_evaluate
-
-from lm_eval.api.model import LM
+if TYPE_CHECKING:
+    from lm_eval.api.model import LM

 eval_logger = logging.getLogger(__name__)
...
@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}
 def register_model(*names):
+    from lm_eval.api.model import LM
+
     # either pass a list or a single alias.
     # function receives them as a tuple of strings
...
@@ -31,7 +32,7 @@ def register_model(*names):
     return decorate


-def get_model(model_name):
+def get_model(model_name: str) -> type["LM"]:
     try:
         return MODEL_REGISTRY[model_name]
     except KeyError:
...
@@ -46,7 +47,7 @@ ALL_TASKS = set()
 func2task_index = {}


-def register_task(name):
+def register_task(name: str):
     def decorate(fn):
         assert name not in TASK_REGISTRY, (
             f"task named '{name}' conflicts with existing registered task!"
...
@@ -120,7 +121,7 @@ def register_metric(**args):
     return decorate


-def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
+def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
     if not hf_evaluate_metric:
         if name in METRIC_REGISTRY:
             return METRIC_REGISTRY[name]
...
@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
         )

     try:
+        import evaluate as hf_evaluate
+
         metric_object = hf_evaluate.load(name)
         return metric_object.compute
     except Exception:
...
@@ -150,21 +153,21 @@ def register_aggregation(name: str):
     return decorate


-def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} not a registered aggregation metric!")


-def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return METRIC_AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


-def is_higher_better(metric_name) -> bool:
+def is_higher_better(metric_name) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
...
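The lookup helpers now advertise `Optional[...]` returns: on a registry miss they log a warning and fall through to `None` rather than raising. Callers should guard accordingly. A short sketch, assuming the stock registry where "mean" is a registered aggregation:

from lm_eval.api.registry import get_aggregation

agg = get_aggregation("mean")          # registered -> returns the callable
missing = get_aggregation("no_such")   # unregistered -> warning logged, None returned
if missing is None:
    # fail fast instead of crashing later on a None call
    raise KeyError("no_such is not a registered aggregation")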
lm_eval/api/samplers.py  [view file @ 1b5c6f88]

 import logging
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union

 import datasets
...
@@ -181,7 +181,7 @@ class ContextSampler:
         return chat_history

-    def sample(self, n: int):
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """
...
@@ -190,7 +190,7 @@ class ContextSampler:
 class FirstNSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...
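The corrected `-> Sequence[dict]` annotations state the actual contract: samplers return fewshot documents (dicts), not `None`. A standalone toy sketch of a conforming sampler (the real `ContextSampler` constructor takes more arguments):

from typing import Sequence

class ToyFirstNSampler:
    """Standalone stand-in for FirstNSampler's contract."""

    def __init__(self, docs: Sequence[dict]):
        self.docs = list(docs)

    def sample(self, n: int) -> Sequence[dict]:
        # first `n` docs, in canonical order, as FirstNSampler's docstring describes
        assert n <= len(self.docs), "not enough fewshot docs"
        return self.docs[:n]

print(ToyFirstNSampler([{"q": 1}, {"q": 2}, {"q": 3}]).sample(2))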
lm_eval/api/task.py  [view file @ 1b5c6f88]

...
@@ -6,6 +6,7 @@ import re
 from collections.abc import Callable
 from copy import deepcopy
 from dataclasses import asdict, dataclass
+from functools import cached_property
 from inspect import getsource
 from typing import (
     Any,
...
@@ -23,6 +24,7 @@ from typing import (
 import datasets
 import numpy as np
 from tqdm import tqdm
+from typing_extensions import deprecated

 from lm_eval import utils
 from lm_eval.api import samplers
...
@@ -51,6 +53,43 @@ ALL_OUTPUT_TYPES = [
 eval_logger = logging.getLogger(__name__)


+@dataclass
+class MetricConfig:
+    """Encapsulates information about a single metric."""
+
+    name: str
+    fn: Optional[Callable] = None
+    kwargs: Optional[dict] = None
+    aggregation_fn: Optional[Callable] = None
+    higher_is_better: bool = True
+    hf_evaluate: bool = False
+
+    @cached_property
+    def metric_names(self) -> str:
+        return self.name
+
+    @cached_property
+    def aggregation(self) -> Callable:
+        if self.aggregation_fn is None:
+            return get_aggregation(self.name)
+        return self.aggregation_fn
+
+    @cached_property
+    def _higher_is_better(self) -> bool:
+        if self.higher_is_better is None:
+            return is_higher_better(self.name)
+        return self.higher_is_better
+
+
+@dataclass
+class FilterConfig:
+    """Encapsulates information about a filter."""
+
+    name: str
+    fn: Optional[Callable] = None
+    kwargs: Optional[dict] = None
+
+
 @dataclass
 class TaskConfig(dict):
     # task naming/registry
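A minimal sketch of the new dataclass in isolation, assuming `lm_eval` is importable. Passing an explicit `aggregation_fn` sidesteps the registry fallback, which would call `get_aggregation(self.name)` and only succeeds when the metric name doubles as a registered aggregation:

from statistics import mean

from lm_eval.api.task import MetricConfig

mc = MetricConfig(name="exact_match", aggregation_fn=mean)
print(mc.metric_names)            # exact_match
print(mc.aggregation([0, 1, 1]))  # 0.666... (cached_property returns mean directly)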
...
@@ -133,6 +172,93 @@ class TaskConfig(dict):
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )
+        if self.metric_list is not None:
+            for metric_config in self.metric_list:
+                if "metric" not in metric_config:
+                    raise ValueError(
+                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
+                    )
+
+    def get_metrics(self) -> list["MetricConfig"]:
+        metrics = []
+        if self.metric_list is None:
+            _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
+            metrics.extend(
+                MetricConfig(
+                    name=metric_name,
+                    fn=get_metric(metric_name),
+                    aggregation_fn=get_metric_aggregation(metric_name),
+                    higher_is_better=is_higher_better(metric_name),
+                )
+                for metric_name in _metric_list
+            )
+        else:
+            for metric_config in self.metric_list:
+                if "metric" not in metric_config:
+                    raise ValueError(
+                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
+                    )
+                metric_name = metric_config["metric"]
+                _metric_fn_kwargs = {
+                    key: metric_config[key]
+                    for key in metric_config
+                    if key not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
+                }
+                _hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
+                _metric_fn = None
+                _aggregation = None
+                if self.process_results is not None:
+                    # User will compute metrics inside `process_results()`
+                    _metric_name = None
+                    _metric_fn_kwargs = {}
+                elif callable(metric_name):
+                    # User passed a function object
+                    _metric_name = metric_name.__name__
+                    _metric_fn = metric_name.__call__
+                else:
+                    # Normal: look up by name
+                    _metric_name = get_metric(metric_name, _hf_evaluate_metric)
+                # ---------- 3. Decide how to aggregate examples ----------
+                if "aggregation" in metric_config:
+                    if isinstance(_agg_name := metric_config["aggregation"], str):
+                        _aggregation = get_aggregation(_agg_name)
+                    elif callable(_agg_name):  # noqa: E721
+                        _aggregation = metric_config["aggregation"]
+                else:
+                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
+                    _aggregation = get_metric_aggregation(metric_name)
+                    eval_logger.warning(
+                        f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
+                        f"using default "
+                        f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
+                    )
+                # ---------- 4. Determine “higher-is-better” semantics ----------
+                if "higher_is_better" in metric_config:
+                    _higher_is_better = metric_config["higher_is_better"]
+                else:
+                    eval_logger.warning(
+                        f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
+                        f"using default "
+                        f"higher_is_better={is_higher_better(metric_name)}"
+                    )
+                    _higher_is_better = is_higher_better(metric_name)
+                metrics.append(
+                    MetricConfig(
+                        name=_metric_name,
+                        fn=_metric_fn,
+                        kwargs=_metric_fn_kwargs,
+                        aggregation_fn=_aggregation,
+                        higher_is_better=_higher_is_better,
+                        hf_evaluate=_hf_evaluate_metric,
+                    )
+                )
+        return metrics
+
     def __getitem__(self, item):
         return getattr(self, item)
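For orientation, a sketch of what `get_metrics()` does with a typical YAML-derived `metric_list` entry (the values are illustrative, patterned on common task configs):

# Illustrative input, as it would arrive from a task YAML:
metric_list = [
    {
        "metric": "exact_match",
        "aggregation": "mean",
        "higher_is_better": True,
        "ignore_case": True,
    }
]
# For this entry, get_metrics() resolves roughly to a MetricConfig with:
#   kwargs           -> {"ignore_case": True}   (everything but the reserved keys)
#   aggregation_fn   -> get_aggregation("mean")
#   higher_is_better -> True
#   hf_evaluate      -> False                    (not set, defaults off)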
...
@@ -534,7 +660,7 @@ class Task(abc.ABC):
         """
         pass

-    @abc.abstractmethod
+    @deprecated("not used anymore")
     def aggregation(self):
         """
         :returns: {str: [metric_score] -> float}
...
@@ -543,7 +669,7 @@ class Task(abc.ABC):
         """
         pass

-    @abc.abstractmethod
+    @deprecated("not used anymore")
     def higher_is_better(self):
         """
         :returns: {str: bool}
...
@@ -661,23 +787,13 @@ class Task(abc.ABC):
         Parameters:
         - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
         """
-        (
-            self._metric_fn_list,
-            self._aggregation_list,
-            self._metric_fn_kwargs,
-            self._higher_is_better,
-        ) = ({}, {}, {}, {})
-        self._metric_fn_list[metric_name] = get_metric(metric_name)
-        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
-        self._higher_is_better[metric_name] = is_higher_better(metric_name)
-        self._metric_fn_kwargs[metric_name] = {}
-        if not isinstance(self, ConfigurableTask):
-            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
-            self.aggregation = lambda: {
-                metric_name: get_metric_aggregation(metric_name)
-            }
-        setattr(self._config, "metric_list", [{"metric": metric_name}])
-        setattr(self._config, "process_results", None)
+        # if not isinstance(self, ConfigurableTask):
+        #     self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
+        #     self.aggregation = lambda: {
+        #         metric_name: get_metric_aggregation(metric_name)
+        #     }
+        setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
+        setattr(self._config, "process_results", lambda *args: {"bypass": 0})

     def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
         self.fewshot_rnd = random.Random(seed)
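The rewritten `override_metric` routes everything through the config: the task's `metric_list` collapses to a single `MetricConfig(name=metric_name)`, and `process_results` is stubbed to emit a constant bypass score instead of the old quartet of parallel dicts. A standalone sketch of the installed stub's behavior:

def process_results(*args):
    # the stub installed by override_metric, reproduced in isolation
    return {"bypass": 0}

# Whatever (doc, results) pair the evaluator passes in, the score is fixed:
print(process_results({"doc": 1}, ["some model output"]))  # {'bypass': 0}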
...
@@ -739,7 +855,7 @@ class ConfigurableTask(Task):
         cache_dir=None,
         download_mode=None,
         config: Optional[dict] = None,
-    ) -> None:  # TODO no super() call here
+    ) -> None:
         # Get pre-configured attributes
         self._config = self.CONFIG
...
@@ -784,83 +900,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name

-        self._metric_fn_list = {}
-        self._metric_fn_kwargs = {}
-        self._aggregation_list = {}
-        self._higher_is_better = {}
-
-        if self.config.metric_list is None:
-            # TODO: handle this in TaskConfig.__post_init__ ?
-            _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
-
-            for metric_name in _metric_list:
-                self._metric_fn_list[metric_name] = get_metric(metric_name)
-                self._metric_fn_kwargs[metric_name] = {}
-                self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
-                self._higher_is_better[metric_name] = is_higher_better(metric_name)
-        else:
-            for metric_config in self.config.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
-                metric_name = metric_config["metric"]
-                kwargs = {
-                    key: metric_config[key]
-                    for key in metric_config
-                    if key not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
-                }
-                hf_evaluate_metric = (
-                    "hf_evaluate" in metric_config
-                    and metric_config["hf_evaluate"] is True
-                )
-
-                if self.config.process_results is not None:
-                    self._metric_fn_list[metric_name] = None
-                    self._metric_fn_kwargs[metric_name] = {}
-                elif callable(metric_name):
-                    metric_fn = metric_name.__call__
-                    metric_name = metric_name.__name__
-                    self._metric_fn_list[metric_name] = metric_fn
-                    self._metric_fn_kwargs[metric_name] = kwargs
-                else:
-                    self._metric_fn_list[metric_name] = get_metric(metric_name, hf_evaluate_metric)
-                    self._metric_fn_kwargs[metric_name] = kwargs
-
-                if "aggregation" in metric_config:
-                    agg_name = metric_config["aggregation"]
-                    if isinstance(agg_name, str):
-                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
-                    elif callable(agg_name):  # noqa: E721
-                        self._aggregation_list[metric_name] = metric_config["aggregation"]
-                else:
-                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
-                    metric_agg = get_metric_aggregation(metric_name)
-                    eval_logger.warning(
-                        f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
-                        f"using default "
-                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
-                    )
-                    self._aggregation_list[metric_name] = metric_agg
-
-                if "higher_is_better" in metric_config:
-                    self._higher_is_better[metric_name] = metric_config["higher_is_better"]
-                else:
-                    eval_logger.warning(
-                        f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
-                        f"using default "
-                        f"higher_is_better={is_higher_better(metric_name)}"
-                    )
-                    self._higher_is_better[metric_name] = is_higher_better(metric_name)
+        self.metric_list: list[MetricConfig] = self._config.get_metrics()

         self.download(self.config.dataset_kwargs)
         self._training_docs = None
...
@@ -868,17 +908,23 @@ class ConfigurableTask(Task):
         if self.config.filter_list is not None:
             self._filters = []
-            for filter_config in self.config.filter_list:
-                filter_name = filter_config["name"]
-                filter_functions = filter_config["filter"]
-                components = []
-                for function in filter_functions:
-                    kwargs = {
-                        key: function[key] for key in function if key != "function"
-                    }
-                    components.append([function["function"], kwargs])
-                filter_pipeline = build_filter_ensemble(filter_name, components)
-                self._filters.append(filter_pipeline)
+            if isinstance(self.config.filter_list, dict):
+                for filter_config in self.config.filter_list:
+                    self._filters.append(
+                        build_filter_ensemble(
+                            filter_config["name"],
+                            [
+                                [
+                                    {
+                                        key: function[key]
+                                        for key in function
+                                        if key != "function"
+                                    }
+                                ]
+                                for function in filter_config["filter"]
+                            ],
+                        )
+                    )
         else:
             # TODO: handle repeats in a more general way rather than just discarding
             eval_logger.debug(
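For reference, a `filter_list` entry in the shape this code consumes: each config carries a `name` and a list of `filter` steps, where every step is a dict with a `function` key plus that filter's kwargs. The values below are illustrative, patterned on common `get-answer` configs ("regex" and "take_first" are filters registered in lm_eval.filters):

filter_list = [
    {
        "name": "get-answer",
        "filter": [
            {"function": "regex", "regex_pattern": r"(-?[0-9.,]+)"},
            {"function": "take_first"},
        ],
    }
]
# The `key != "function"` comprehension strips the "function" key, leaving
# {"regex_pattern": "(-?[0-9.,]+)"} as the kwargs for the "regex" step.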
...
@@ -1476,7 +1522,7 @@ class ConfigurableTask(Task):
             arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

             # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in self._metric_fn_list.keys():
+            if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
                 # if we are calculating multiple choice accuracy
                 # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
...
@@ -1543,7 +1589,7 @@ class ConfigurableTask(Task):
             return self.config.process_results(doc, results)

         result_dict = {}
-        use_metric = list(self._metric_fn_list.keys())
+        use_metric = list(m.metric_names for m in self.metric_list)
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
...
@@ -1579,10 +1625,7 @@ class ConfigurableTask(Task):
             choices = self.doc_to_choice(doc)
             completion_len = np.array([float(len(i)) for i in choices])

-            if (
-                2 * len(choices) == len(lls)
-                and "acc_mutual_info" in self._metric_fn_list.keys()
-            ):
+            if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
                 # then we are doing mutual info.
                 # this stores the "dryrun" / unconditional answer loglikelihoods
                 # as we extend the args list with unconditional ("", continuation) pairs
...
@@ -1667,12 +1710,12 @@ class ConfigurableTask(Task):
                 gold = list(gold)
             # TODO: handle this better
             elif type(gold) is not type(result) and not (
-                "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
+                "bypass" in use_metric or isinstance(result, list)
             ):
                 # cast gold to the same type as result
                 gold = type(result)(gold)

-            for metric in self._metric_fn_list.keys():
+            for metric in self.metric_list:
                 if self.multiple_target:
                     # in the case where we have multiple targets,
                     # return true if any are true
...
@@ -1682,28 +1725,26 @@ class ConfigurableTask(Task):
                         # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                         # print(gold)
                         gold = [gold]
-                    if metric == "exact_match":
+                    if metric.name == "exact_match":
                         result = [result for _ in range(len(gold))]
-                        scores = self._metric_fn_list[metric](
+                        scores = metric.fn(
                             references=gold,
                             predictions=result,
-                            **self._metric_fn_kwargs[metric],
+                            **metric.kwargs,
                         )[metric]
                         result_score = 1.0 if scores > 0.0 else 0.0
                     else:
                         for gold_option in gold:
                             try:
-                                result_score = self._metric_fn_list[metric](
+                                result_score = metric.fn(
                                     references=[gold_option],
                                     predictions=[result],
-                                    **self._metric_fn_kwargs[metric],
+                                    **metric.kwargs,
                                 )
                             except (
                                 TypeError
                             ):  # TODO: this is hacky and I don't want to do it
-                                result_score = self._metric_fn_list[metric](
-                                    [gold_option, result]
-                                )
+                                result_score = metric.fn([gold_option, result])
                             if isinstance(result_score, dict):
                                 # TODO: this handles the case where HF evaluate returns a dict.
                                 result_score = result_score[metric]
...
@@ -1714,13 +1755,13 @@ class ConfigurableTask(Task):
                     result_score = 0.0
             else:
                 try:
-                    result_score = self._metric_fn_list[metric](
+                    result_score = metric.fn(
                         references=[gold],
                         predictions=[result],
-                        **self._metric_fn_kwargs[metric],
+                        **metric.kwargs,
                     )
                 except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
-                    result_score = self._metric_fn_list[metric]([gold, result])
+                    result_score = metric.fn([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
                     # This allows for multiple metrics to be returned from the same function
...
@@ -1737,10 +1778,10 @@ class ConfigurableTask(Task):
         return result_dict

     def aggregation(self) -> dict:
-        return self._aggregation_list
+        return {k.name: k.aggregation_fn for k in self.metric_list}

     def higher_is_better(self) -> dict:
-        return self._higher_is_better
+        return {k.name: k.higher_is_better for k in self.metric_list}

     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)
...
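The net effect of the hunks above: `process_results` no longer indexes four parallel dicts but iterates `MetricConfig` objects directly. The scoring call reduces to the shape in this sketch (a hypothetical helper, not part of the library; note that `kwargs` defaults to `None` on `MetricConfig`, hence the `or {}` guard):

def score_all(metric_list, gold, result):
    # Sketch of the per-metric loop in process_results after this commit.
    scores = {}
    for metric in metric_list:  # each item is a MetricConfig
        scores[metric.name] = metric.fn(
            references=[gold],
            predictions=[result],
            **(metric.kwargs or {}),
        )
    return scores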
lm_eval/evaluator.py  [view file @ 1b5c6f88]

...
@@ -272,7 +272,7 @@ def simple_evaluate(
     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
     # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
-    def _adjust_config(task_dict):
+    def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
         adjusted_task_dict = {}
         for task_name, task_obj in task_dict.items():
             if isinstance(task_obj, dict):
...
lm_eval/evaluator_utils.py  [view file @ 1b5c6f88]

...
@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
     pooled_sample_stderr,
     stderr_for_metric,
 )
-from lm_eval.api.task import Task
+from lm_eval.api.task import ConfigurableTask, Task
 from lm_eval.utils import positional_deprecated
...
@@ -58,7 +58,7 @@ class TaskOutput:
         group_alias=None,
         is_group=None,
     ):
-        self.task = task
+        self.task: Union[Task, ConfigurableTask] = task
         self.task_config = task_config
         self.task_name = task_name
         self.group_name = group_name
...
lm_eval/filters/__init__.py  [view file @ 1b5c6f88]

 from functools import partial
-from typing import List
+from typing import List, Union

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.registry import get_filter
...
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
 def build_filter_ensemble(
-    filter_name: str, components: List[List[str]]
+    filter_name: str, components: list[Union[list[dict], list[str]]]
 ) -> FilterEnsemble:
     """
     Create a filtering pipeline.
...
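The widened `components` type reflects how `ConfigurableTask` now calls this function. A sketch of a direct call using the `[function_name, kwargs]` component shape, assuming lm_eval is importable and that "regex" and "take_first" are registered filter names:

from lm_eval.filters import build_filter_ensemble

# One pipeline: extract a number with a regex, then keep the first match.
ensemble = build_filter_ensemble(
    "get-answer",
    [
        ["regex", {"regex_pattern": r"(-?[0-9.,]+)"}],
        ["take_first", None],  # None kwargs, i.e. use the filter's defaults
    ],
)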