gaoqiong / lm-evaluation-harness · Commits · 28c78d30
Commit 28c78d30 authored Jun 30, 2025 by Baber

    add MetricConfig

Parent: de496b80
Showing 10 changed files with 223 additions and 156 deletions.
  lm_eval/__main__.py           +4    -0
  lm_eval/api/group.py          +1    -2
  lm_eval/api/instance.py       +17   -4
  lm_eval/api/metrics.py        +6    -1
  lm_eval/api/registry.py       +13   -10
  lm_eval/api/samplers.py       +3    -3
  lm_eval/api/task.py           +174  -131
  lm_eval/evaluator.py          +1    -1
  lm_eval/evaluator_utils.py    +2    -2
  lm_eval/filters/__init__.py   +2    -2
lm_eval/__main__.py

@@ -485,6 +485,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if results is not None:
         if args.log_samples:
             samples = results.pop("samples")
+        # TODO: fix this!
+        results["higher_is_better"] = {
+            k: True for k, v in results["higher_is_better"].items()
+        }
         dumped = json.dumps(
             results, indent=2, default=handle_non_serializable, ensure_ascii=False
         )
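Note: the four added lines force every entry of results["higher_is_better"] to True before the JSON dump; the inline TODO marks this as a stopgap. A minimal sketch of the effect, with hypothetical metric names (the real dict comes from the evaluator):

    results = {"higher_is_better": {"acc": True, "perplexity": False}}
    # same comprehension as the added lines: every value is coerced to True
    results["higher_is_better"] = {k: True for k, v in results["higher_is_better"].items()}
    assert results["higher_is_better"] == {"acc": True, "perplexity": True}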
lm_eval/api/group.py

-import abc
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import Any, Callable, List, Optional, Union

@@ -84,7 +83,7 @@ class GroupConfig(dict):
         return str(value)


-class ConfigurableGroup(abc.ABC):
+class ConfigurableGroup:
     def __init__(
         self,
         config: Optional[dict] = None,
lm_eval/api/instance.py

@@ -14,10 +14,23 @@ class Instance:
     arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
-        default_factory=lambda: (None, None, None)
+        default_factory=lambda: (None, None, None),
+        metadata=dict(
+            description="Metadata tuple containing task name, document ID, and number of repeats."
+        ),
     )
-    resps: list = field(default_factory=list)
-    filtered_resps: dict = field(default_factory=dict)
+    resps: list = field(
+        default_factory=list,
+        metadata=dict(description="List of responses from the model for this instance."),
+    )
+    filtered_resps: dict = field(
+        default_factory=dict,
+        metadata=dict(
+            description="List of filtered responses for this instance, keyed by filter name."
+        ),
+    )

     # initialized after init
     task_name: Optional[str] = None

@@ -29,7 +42,7 @@ class Instance:
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self):
+    def args(self) -> tuple:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
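The descriptions added here use the dataclasses field(metadata=...) channel, which the dataclass machinery itself ignores but exposes through dataclasses.fields(). A self-contained sketch of that mechanism (toy class, not the real Instance):

    from dataclasses import dataclass, field, fields

    @dataclass
    class Toy:
        resps: list = field(
            default_factory=list,
            metadata=dict(description="List of responses from the model for this instance."),
        )

    # Tools (docs generators, validators) can read the metadata back:
    for f in fields(Toy):
        print(f.name, "->", f.metadata.get("description"))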
lm_eval/api/metrics.py

@@ -8,7 +8,6 @@ from collections.abc import Iterable
 from typing import Callable, List, Optional, Sequence, TypeVar

 import numpy as np
-import sacrebleu

 from lm_eval.api.registry import register_aggregation, register_metric

@@ -92,6 +91,8 @@ def bleu(items):
     Higher is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)

@@ -107,6 +108,8 @@ def chrf(items):
     Higher is better  # TODO I think
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)

@@ -123,6 +126,8 @@ def ter(items):
     Lower is better
     """
+    import sacrebleu
+
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
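Moving import sacrebleu from module level into the three metric bodies is a lazy-import pattern: the package is only loaded (and only needs to be installed) when a sacrebleu-backed metric actually runs. A sketch of the same idea with a hypothetical metric function:

    def bleu_like(refs, preds):
        # Deferred import: module import stays cheap, and users who never
        # request this metric never touch the dependency.
        import sacrebleu

        # sacrebleu expects a list of hypotheses and a list of reference streams
        return sacrebleu.corpus_bleu(preds, [refs]).score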
lm_eval/api/registry.py

 import logging
-from typing import Callable, Dict, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

-import evaluate as hf_evaluate
-
-from lm_eval.api.model import LM
+if TYPE_CHECKING:
+    from lm_eval.api.model import LM

 eval_logger = logging.getLogger(__name__)

@@ -12,6 +11,8 @@ MODEL_REGISTRY = {}
 def register_model(*names):
+    from lm_eval.api.model import LM
+
     # either pass a list or a single alias.
     # function receives them as a tuple of strings

@@ -31,7 +32,7 @@ def register_model(*names):
     return decorate


-def get_model(model_name):
+def get_model(model_name: str) -> type["LM"]:
     try:
         return MODEL_REGISTRY[model_name]
     except KeyError:

@@ -46,7 +47,7 @@ ALL_TASKS = set()
 func2task_index = {}


-def register_task(name):
+def register_task(name: str):
     def decorate(fn):
         assert name not in TASK_REGISTRY, (
             f"task named '{name}' conflicts with existing registered task!"

@@ -120,7 +121,7 @@ def register_metric(**args):
     return decorate


-def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
+def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
     if not hf_evaluate_metric:
         if name in METRIC_REGISTRY:
             return METRIC_REGISTRY[name]

@@ -130,6 +131,8 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
         )

     try:
+        import evaluate as hf_evaluate
+
         metric_object = hf_evaluate.load(name)
         return metric_object.compute
     except Exception:

@@ -150,21 +153,21 @@ def register_aggregation(name: str):
     return decorate


-def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} not a registered aggregation metric!")


-def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
+def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
     try:
         return METRIC_AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


-def is_higher_better(metric_name) -> bool:
+def is_higher_better(metric_name) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
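The TYPE_CHECKING guard plus the function-local import in register_model is the usual way to break a circular import while keeping annotations usable: type checkers see the import, the runtime defers it until the decorator executes. A minimal sketch of the pattern with hypothetical module names:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by static type checkers, never at runtime
        from mypkg.model import LM  # hypothetical module path

    MODEL_REGISTRY: dict = {}

    def register_model(*names):
        # runtime import happens here, after both modules are fully loaded
        from mypkg.model import LM  # hypothetical module path

        def decorate(cls):
            assert issubclass(cls, LM), f"{cls} must subclass LM"
            for name in names:
                MODEL_REGISTRY[name] = cls
            return cls

        return decorate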
lm_eval/api/samplers.py

 import logging
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union

 import datasets

@@ -181,7 +181,7 @@ class ContextSampler:
         return chat_history

-    def sample(self, n: int):
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """

@@ -190,7 +190,7 @@
 class FirstNSampler(ContextSampler):
-    def sample(self, n: int) -> None:
+    def sample(self, n: int) -> Sequence[dict]:
         """
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
lm_eval/api/task.py

@@ -6,6 +6,7 @@ import re
 from collections.abc import Callable
 from copy import deepcopy
 from dataclasses import asdict, dataclass
+from functools import cached_property
 from inspect import getsource
 from typing import (
     Any,

@@ -23,6 +24,7 @@ from typing import (
 import datasets
 import numpy as np
 from tqdm import tqdm
+from typing_extensions import deprecated

 from lm_eval import utils
 from lm_eval.api import samplers
@@ -51,6 +53,43 @@ ALL_OUTPUT_TYPES = [
 eval_logger = logging.getLogger(__name__)


+@dataclass
+class MetricConfig:
+    """Encapsulates information about a single metric."""
+
+    name: str
+    fn: Optional[Callable] = None
+    kwargs: Optional[dict] = None
+    aggregation_fn: Optional[Callable] = None
+    higher_is_better: bool = True
+    hf_evaluate: bool = False
+
+    @cached_property
+    def metric_names(self) -> str:
+        return self.name
+
+    @cached_property
+    def aggregation(self) -> Callable:
+        if self.aggregation_fn is None:
+            return get_aggregation(self.name)
+        return self.aggregation_fn
+
+    @cached_property
+    def _higher_is_better(self) -> bool:
+        if self.higher_is_better is None:
+            return is_higher_better(self.name)
+        return self.higher_is_better
+
+
+@dataclass
+class FilterConfig:
+    """Encapsulates information about a single filter."""
+
+    name: str
+    fn: Optional[Callable] = None
+    kwargs: Optional[dict] = None
+
+
 @dataclass
 class TaskConfig(dict):
     # task naming/registry
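A short sketch of how the new MetricConfig resolves defaults lazily; the registry lookup is replaced by a stand-in so the example is self-contained (values illustrative):

    from dataclasses import dataclass
    from functools import cached_property
    from typing import Callable, Optional

    def get_aggregation(name):  # stand-in for lm_eval.api.registry.get_aggregation
        return {"acc": lambda xs: sum(xs) / len(xs)}.get(name)

    @dataclass
    class MetricConfig:  # trimmed copy of the class added above
        name: str
        aggregation_fn: Optional[Callable] = None

        @cached_property
        def aggregation(self) -> Callable:
            # Fall back to the registry default when no explicit fn was given;
            # cached_property memoizes the lookup on first access.
            if self.aggregation_fn is None:
                return get_aggregation(self.name)
            return self.aggregation_fn

    m = MetricConfig(name="acc")
    print(m.aggregation([1, 0, 1]))  # 0.666...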
@@ -99,6 +138,8 @@ class TaskConfig(dict):
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
+    _metric_list = None
+    _filter_list = None

     def __post_init__(self) -> None:
         if self.generation_kwargs is not None:
@@ -133,6 +174,93 @@ class TaskConfig(dict):
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )
+        if self.metric_list is not None:
+            for metric_config in self.metric_list:
+                if "metric" not in metric_config:
+                    raise ValueError(
+                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
+                    )
+
+    def get_metrics(self) -> list["MetricConfig"]:
+        metrics = []
+        if self.metric_list is None:
+            _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
+            metrics.extend(
+                MetricConfig(
+                    name=metric_name,
+                    fn=get_metric(metric_name),
+                    aggregation_fn=get_metric_aggregation(metric_name),
+                    higher_is_better=is_higher_better(metric_name),
+                )
+                for metric_name in _metric_list
+            )
+        else:
+            for metric_config in self.metric_list:
+                if "metric" not in metric_config:
+                    raise ValueError(
+                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
+                    )
+                metric_name = metric_config["metric"]
+                _metric_fn_kwargs = {
+                    key: metric_config[key]
+                    for key in metric_config
+                    if key not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
+                }
+                _hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
+                _metric_fn = None
+                _aggregation = None
+                if self.process_results is not None:
+                    # User will compute metrics inside `process_results()`
+                    _metric_name = None
+                    _metric_fn_kwargs = {}
+                elif callable(metric_name):
+                    # User passed a function object
+                    _metric_name = metric_name.__name__
+                    _metric_fn = metric_name.__call__
+                else:
+                    # Normal: look up by name
+                    _metric_name = get_metric(metric_name, _hf_evaluate_metric)
+                # ---------- 3. Decide how to aggregate examples ----------
+                if "aggregation" in metric_config:
+                    if isinstance(_agg_name := metric_config["aggregation"], str):
+                        _aggregation = get_aggregation(_agg_name)
+                    elif callable(_agg_name):  # noqa: E721
+                        _aggregation = metric_config["aggregation"]
+                else:
+                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
+                    _aggregation = get_metric_aggregation(metric_name)
+                    eval_logger.warning(
+                        f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
+                        f"using default "
+                        f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
+                    )
+                # ---------- 4. Determine “higher-is-better” semantics ----------
+                if "higher_is_better" in metric_config:
+                    _higher_is_better = metric_config["higher_is_better"]
+                else:
+                    eval_logger.warning(
+                        f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
+                        f"using default "
+                        f"higher_is_better={is_higher_better(metric_name)}"
+                    )
+                    _higher_is_better = is_higher_better(metric_name)
+                metrics.append(
+                    MetricConfig(
+                        name=_metric_name,
+                        fn=_metric_fn,
+                        kwargs=_metric_fn_kwargs,
+                        aggregation_fn=_aggregation,
+                        higher_is_better=_higher_is_better,
+                        hf_evaluate=_hf_evaluate_metric,
+                    )
+                )
+        return metrics
+
     def __getitem__(self, item):
         return getattr(self, item)
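For orientation, get_metrics consumes the same metric_list entries that task configs already declare; a hypothetical entry (keys illustrative) maps onto the branches above like so:

    metric_config = {
        "metric": "exact_match",   # looked up via get_metric(...) unless callable
        "aggregation": "mean",     # a str goes through get_aggregation("mean"); may also be a callable
        "higher_is_better": True,  # omitted -> default from is_higher_better(...)
        "ignore_case": True,       # any remaining key is collected into MetricConfig.kwargs
    }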
@@ -534,7 +662,7 @@ class Task(abc.ABC):
         """
         pass

-    @abc.abstractmethod
+    @deprecated("not used anymore")
     def aggregation(self):
         """
         :returns: {str: [metric_score] -> float}

@@ -543,7 +671,7 @@
         """
         pass

-    @abc.abstractmethod
+    @deprecated("not used anymore")
     def higher_is_better(self):
         """
         :returns: {str: bool}
@@ -661,23 +789,13 @@ class Task(abc.ABC):
         Parameters:
         - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
         """
-        (
-            self._metric_fn_list,
-            self._aggregation_list,
-            self._metric_fn_kwargs,
-            self._higher_is_better,
-        ) = ({}, {}, {}, {})
-        self._metric_fn_list[metric_name] = get_metric(metric_name)
-        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
-        self._higher_is_better[metric_name] = is_higher_better(metric_name)
-        self._metric_fn_kwargs[metric_name] = {}
-        if not isinstance(self, ConfigurableTask):
-            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
-            self.aggregation = lambda: {
-                metric_name: get_metric_aggregation(metric_name)
-            }
-        setattr(self._config, "metric_list", [{"metric": metric_name}])
-        setattr(self._config, "process_results", None)
+        # if not isinstance(self, ConfigurableTask):
+        #     self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
+        #     self.aggregation = lambda: {
+        #         metric_name: get_metric_aggregation(metric_name)
+        #     }
+        setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
+        setattr(self._config, "process_results", lambda *args: {"bypass": 0})

     def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
         self.fewshot_rnd = random.Random(seed)
@@ -739,7 +857,7 @@ class ConfigurableTask(Task):
         cache_dir=None,
         download_mode=None,
         config: Optional[dict] = None,
-    ) -> None:  # TODO no super() call here
+    ) -> None:
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -784,83 +902,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name

-        self._metric_fn_list = {}
-        self._metric_fn_kwargs = {}
-        self._aggregation_list = {}
-        self._higher_is_better = {}
-
-        if self.config.metric_list is None:
-            # TODO: handle this in TaskConfig.__post_init__ ?
-            _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
-
-            for metric_name in _metric_list:
-                self._metric_fn_list[metric_name] = get_metric(metric_name)
-                self._metric_fn_kwargs[metric_name] = {}
-                self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
-                self._higher_is_better[metric_name] = is_higher_better(metric_name)
-        else:
-            for metric_config in self.config.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
-                metric_name = metric_config["metric"]
-                kwargs = {
-                    key: metric_config[key]
-                    for key in metric_config
-                    if key not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
-                }
-                hf_evaluate_metric = (
-                    "hf_evaluate" in metric_config
-                    and metric_config["hf_evaluate"] is True
-                )
-
-                if self.config.process_results is not None:
-                    self._metric_fn_list[metric_name] = None
-                    self._metric_fn_kwargs[metric_name] = {}
-                elif callable(metric_name):
-                    metric_fn = metric_name.__call__
-                    metric_name = metric_name.__name__
-                    self._metric_fn_list[metric_name] = metric_fn
-                    self._metric_fn_kwargs[metric_name] = kwargs
-                else:
-                    self._metric_fn_list[metric_name] = get_metric(metric_name, hf_evaluate_metric)
-                    self._metric_fn_kwargs[metric_name] = kwargs
-
-                if "aggregation" in metric_config:
-                    agg_name = metric_config["aggregation"]
-                    if isinstance(agg_name, str):
-                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
-                    elif callable(agg_name):  # noqa: E721
-                        self._aggregation_list[metric_name] = metric_config["aggregation"]
-                else:
-                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
-                    metric_agg = get_metric_aggregation(metric_name)
-                    eval_logger.warning(
-                        f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
-                        f"using default "
-                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
-                    )
-                    self._aggregation_list[metric_name] = metric_agg
-
-                if "higher_is_better" in metric_config:
-                    self._higher_is_better[metric_name] = metric_config["higher_is_better"]
-                else:
-                    eval_logger.warning(
-                        f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
-                        f"using default "
-                        f"higher_is_better={is_higher_better(metric_name)}"
-                    )
-                    self._higher_is_better[metric_name] = is_higher_better(metric_name)
+        self.metric_list: list[MetricConfig] = self._config.get_metrics()

         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -868,17 +910,23 @@ class ConfigurableTask(Task):
         if self.config.filter_list is not None:
             self._filters = []
-            for filter_config in self.config.filter_list:
-                filter_name = filter_config["name"]
-                filter_functions = filter_config["filter"]
-                components = []
-                for function in filter_functions:
-                    kwargs = {
-                        key: function[key] for key in function if key != "function"
-                    }
-                    components.append([function["function"], kwargs])
-                filter_pipeline = build_filter_ensemble(filter_name, components)
-                self._filters.append(filter_pipeline)
+            if isinstance(self.config.filter_list, dict):
+                for filter_config in self.config.filter_list:
+                    self._filters.append(
+                        build_filter_ensemble(
+                            filter_config["name"],
+                            [
+                                [{key: function[key] for key in function if key != "function"}]
+                                for function in filter_config["filter"]
+                            ],
+                        )
+                    )
         else:
             # TODO: handle repeats in a more general way rather than just discarding
             eval_logger.debug(
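The rewritten loop consumes filter_list entries of roughly the following shape; a hypothetical example (function names and pattern illustrative, not taken from this commit):

    filter_list = [
        {
            "name": "strict-match",
            "filter": [
                {"function": "regex", "regex_pattern": r"(\-?[0-9\.\,]+)"},
                {"function": "take_first"},
            ],
        }
    ]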
@@ -1297,7 +1345,7 @@ class ConfigurableTask(Task):
             return doc[doc_to_text]
         else:
             text_string = utils.apply_template(doc_to_text, doc)
-            if text_string.isdigit() and self._config.doc_to_choice is not None:
+            if text_string.isdigit() and self.config.doc_to_choice is not None:
                 return ast.literal_eval(text_string)
             else:
                 return text_string

@@ -1333,7 +1381,7 @@ class ConfigurableTask(Task):
             return doc[doc_to_target]
         else:
             target_string = utils.apply_template(doc_to_target, doc)
-            if target_string.isdigit() and self._config.doc_to_choice is not None:
+            if target_string.isdigit() and self.config.doc_to_choice is not None:
                 return ast.literal_eval(target_string)
             elif (
                 len(target_string) >= 2
@@ -1480,7 +1528,7 @@ class ConfigurableTask(Task):
             arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

             # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in self._metric_fn_list.keys():
+            if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
                 # if we are calculating multiple choice accuracy
                 # using mutual information instead of raw loglikelihood as metric, need unconditional lls.

@@ -1547,7 +1595,7 @@ class ConfigurableTask(Task):
             return self.config.process_results(doc, results)

         result_dict = {}
-        use_metric = list(self._metric_fn_list.keys())
+        use_metric = list(m.metric_names for m in self.metric_list)
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
@@ -1583,10 +1631,7 @@ class ConfigurableTask(Task):
             choices = self.doc_to_choice(doc)
             completion_len = np.array([float(len(i)) for i in choices])

-            if (
-                2 * len(choices) == len(lls)
-                and "acc_mutual_info" in self._metric_fn_list.keys()
-            ):
+            if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
                 # then we are doing mutual info.
                 # this stores the "dryrun" / unconditional answer loglikelihoods
                 # as we extend the args list with unconditional ("", continuation) pairs
@@ -1671,12 +1716,12 @@ class ConfigurableTask(Task):
                     gold = list(gold)
                 # TODO: handle this better
                 elif type(gold) is not type(result) and not (
-                    "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
+                    "bypass" in use_metric or isinstance(result, list)
                 ):
                     # cast gold to the same type as result
                     gold = type(result)(gold)

-            for metric in self._metric_fn_list.keys():
+            for metric in self.metric_list:
                 if self.multiple_target:
                     # in the case where we have multiple targets,
                     # return true if any are true

@@ -1686,28 +1731,26 @@ class ConfigurableTask(Task):
                         # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                         # print(gold)
                         gold = [gold]
-                    if metric == "exact_match":
+                    if metric.name == "exact_match":
                         result = [result for _ in range(len(gold))]
-                        scores = self._metric_fn_list[metric](
+                        scores = metric.fn(
                             references=gold,
                             predictions=result,
-                            **self._metric_fn_kwargs[metric],
+                            **metric.kwargs,
                         )[metric]
                         result_score = 1.0 if scores > 0.0 else 0.0
                     else:
                         for gold_option in gold:
                             try:
-                                result_score = self._metric_fn_list[metric](
+                                result_score = metric.fn(
                                     references=[gold_option],
                                     predictions=[result],
-                                    **self._metric_fn_kwargs[metric],
+                                    **metric.kwargs,
                                 )
                             except (TypeError):
                                 # TODO: this is hacky and I don't want to do it
-                                result_score = self._metric_fn_list[metric](
-                                    [gold_option, result]
-                                )
+                                result_score = metric.fn([gold_option, result])
                             if isinstance(result_score, dict):
                                 # TODO: this handles the case where HF evaluate returns a dict.
                                 result_score = result_score[metric]
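Score computation now flows through the MetricConfig object instead of the _metric_fn_list/_metric_fn_kwargs dicts; a compact, self-contained sketch of the new call shape (toy metric, not from the harness):

    from dataclasses import dataclass, field
    from typing import Callable, Optional

    @dataclass
    class MetricConfig:  # trimmed stand-in for the class added in this commit
        name: str
        fn: Optional[Callable] = None
        kwargs: dict = field(default_factory=dict)

    def exact_match(references, predictions, ignore_case=False):
        # toy scorer: fraction of exact string matches
        norm = (lambda s: s.lower()) if ignore_case else (lambda s: s)
        hits = sum(norm(r) == norm(p) for r, p in zip(references, predictions))
        return {"exact_match": hits / len(references)}

    metric = MetricConfig("exact_match", fn=exact_match, kwargs={"ignore_case": True})
    scores = metric.fn(references=["Paris"], predictions=["paris"], **metric.kwargs)
    print(scores["exact_match"])  # 1.0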
@@ -1718,13 +1761,13 @@ class ConfigurableTask(Task):
                         result_score = 0.0
             else:
                 try:
-                    result_score = self._metric_fn_list[metric](
+                    result_score = metric.fn(
                         references=[gold],
                         predictions=[result],
-                        **self._metric_fn_kwargs[metric],
+                        **metric.kwargs,
                     )
                 except TypeError:
                     # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
-                    result_score = self._metric_fn_list[metric]([gold, result])
+                    result_score = metric.fn([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
                     # This allows for multiple metrics to be returned from the same function

@@ -1741,10 +1784,10 @@ class ConfigurableTask(Task):
         return result_dict

     def aggregation(self) -> dict:
-        return self._aggregation_list
+        return {k.name: k.aggregation_fn for k in self.metric_list}

     def higher_is_better(self) -> dict:
-        return self._higher_is_better
+        return {k.name: k.higher_is_better for k in self.metric_list}

     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)
lm_eval/evaluator.py

@@ -287,7 +287,7 @@ def simple_evaluate(
     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
     # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
-    def _adjust_config(task_dict):
+    def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
         adjusted_task_dict = {}
         for task_name, task_obj in task_dict.items():
             if isinstance(task_obj, dict):
lm_eval/evaluator_utils.py

@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
     pooled_sample_stderr,
     stderr_for_metric,
 )
-from lm_eval.api.task import Task
+from lm_eval.api.task import ConfigurableTask, Task
 from lm_eval.utils import positional_deprecated

@@ -58,7 +58,7 @@ class TaskOutput:
         group_alias=None,
         is_group=None,
     ):
-        self.task = task
+        self.task: Union[Task, ConfigurableTask] = task
         self.task_config = task_config
         self.task_name = task_name
         self.group_name = group_name
lm_eval/filters/__init__.py

 from functools import partial
-from typing import List
+from typing import List, Union

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.registry import get_filter

@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
 def build_filter_ensemble(
-    filter_name: str, components: List[List[str]]
+    filter_name: str, components: list[Union[list[dict], list[str]]]
 ) -> FilterEnsemble:
     """
     Create a filtering pipeline.