Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
69d14fb3
Commit
69d14fb3
authored
Jul 21, 2025
by
Baber
Browse files
cleanup
parent
57b8c0b1
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
201 additions
and
259 deletions
+201
-259
lm_eval/api/filter.py
lm_eval/api/filter.py
+6
-4
lm_eval/api/registry.py
lm_eval/api/registry.py
+15
-12
lm_eval/api/task.py
lm_eval/api/task.py
+85
-111
lm_eval/config/metric.py
lm_eval/config/metric.py
+12
-9
lm_eval/config/task.py
lm_eval/config/task.py
+46
-47
lm_eval/config/template.py
lm_eval/config/template.py
+21
-19
lm_eval/config/utils.py
lm_eval/config/utils.py
+6
-6
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+10
-51
No files found.
lm_eval/api/filter.py
View file @
69d14fb3
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Iterable
,
List
,
Union
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.instance
import
Instance
...
@@ -20,7 +20,9 @@ class Filter(ABC):
...
@@ -20,7 +20,9 @@ class Filter(ABC):
"""
"""
@
abstractmethod
@
abstractmethod
def
apply
(
self
,
resps
:
Union
[
List
,
Iterable
],
docs
:
List
[
dict
])
->
Iterable
:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
"""
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...
@@ -40,9 +42,9 @@ class FilterEnsemble:
...
@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
"""
name
:
str
name
:
str
filters
:
L
ist
[
type
[
Filter
]]
filters
:
l
ist
[
type
[
Filter
]]
def
apply
(
self
,
instances
:
L
ist
[
Instance
])
->
None
:
def
apply
(
self
,
instances
:
l
ist
[
Instance
])
->
None
:
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
resps
,
docs
=
list
(
resps
),
list
(
docs
)
resps
,
docs
=
list
(
resps
),
list
(
docs
)
...
...
lm_eval/api/registry.py
View file @
69d14fb3
from
__future__
import
annotations
import
logging
import
logging
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -36,13 +38,14 @@ def register_model(*names):
...
@@ -36,13 +38,14 @@ def register_model(*names):
return
decorate
return
decorate
def
get_model
(
model_name
:
str
)
->
type
[
"
LM
"
]:
def
get_model
(
model_name
:
str
)
->
type
[
LM
]:
try
:
try
:
return
MODEL_REGISTRY
[
model_name
]
return
MODEL_REGISTRY
[
model_name
]
except
KeyError
:
except
KeyError
as
err
:
raise
ValueError
(
available_models
=
", "
.
join
(
MODEL_REGISTRY
.
keys
())
f
"Attempted to load model '
{
model_name
}
', but no model for this name found! Supported model names:
{
', '
.
join
(
MODEL_REGISTRY
.
keys
())
}
"
raise
KeyError
(
)
f
"Model '
{
model_name
}
' not found. Available models:
{
available_models
}
"
)
from
err
TASK_REGISTRY
=
{}
TASK_REGISTRY
=
{}
...
@@ -81,7 +84,7 @@ def register_group(name):
...
@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY
=
{}
OUTPUT_TYPE_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
AGGREGATION_REGISTRY
:
D
ict
[
str
,
Callable
[[],
D
ict
[
str
,
Callable
]]]
=
{}
AGGREGATION_REGISTRY
:
d
ict
[
str
,
Callable
[[],
d
ict
[
str
,
Callable
]]]
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
...
@@ -125,7 +128,7 @@ def register_metric(**args):
...
@@ -125,7 +128,7 @@ def register_metric(**args):
return
decorate
return
decorate
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Optional
[
Callable
]
:
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
[...,
Any
]
|
None
:
if
not
hf_evaluate_metric
:
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
if
name
in
METRIC_REGISTRY
:
return
METRIC_REGISTRY
[
name
]
return
METRIC_REGISTRY
[
name
]
...
@@ -157,21 +160,21 @@ def register_aggregation(name: str):
...
@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return
decorate
return
decorate
def
get_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
Dict
[
str
,
Callable
]]]
:
def
get_aggregation
(
name
:
str
)
->
Callable
[...,
Any
]
|
None
:
try
:
try
:
return
AGGREGATION_REGISTRY
[
name
]
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
def
get_metric_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
D
ict
[
str
,
Callable
]]
]
:
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
d
ict
[
str
,
Callable
]]
|
None
:
try
:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
def
is_higher_better
(
metric_name
:
str
)
->
Optional
[
bool
]
:
def
is_higher_better
(
metric_name
:
str
)
->
bool
|
None
:
try
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
except
KeyError
:
...
@@ -192,7 +195,7 @@ def register_filter(name: str):
...
@@ -192,7 +195,7 @@ def register_filter(name: str):
return
decorate
return
decorate
def
get_filter
(
filter_name
:
Union
[
str
,
Callable
]
)
->
Callable
:
def
get_filter
(
filter_name
:
str
|
Callable
)
->
Callable
:
try
:
try
:
return
FILTER_REGISTRY
[
filter_name
]
return
FILTER_REGISTRY
[
filter_name
]
except
KeyError
as
e
:
except
KeyError
as
e
:
...
...
lm_eval/api/task.py
View file @
69d14fb3
from
__future__
import
annotations
import
abc
import
abc
import
ast
import
ast
import
logging
import
logging
...
@@ -8,15 +10,7 @@ from copy import deepcopy
...
@@ -8,15 +10,7 @@ from copy import deepcopy
from
typing
import
(
from
typing
import
(
TYPE_CHECKING
,
TYPE_CHECKING
,
Any
,
Any
,
Dict
,
Iterable
,
Iterator
,
List
,
Literal
,
Literal
,
Mapping
,
Optional
,
Tuple
,
Union
,
)
)
import
datasets
import
datasets
...
@@ -57,23 +51,23 @@ class Task(abc.ABC):
...
@@ -57,23 +51,23 @@ class Task(abc.ABC):
{"question": ..., question, answer)
{"question": ..., question, answer)
"""
"""
VERSION
:
Optional
[
Union
[
int
,
str
]]
=
None
VERSION
:
int
|
str
|
None
=
None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
# or a path to a custom `datasets` loading script.
DATASET_PATH
:
Optional
[
str
]
=
None
DATASET_PATH
:
str
|
None
=
None
# The name of a subset within `DATASET_PATH`.
# The name of a subset within `DATASET_PATH`.
DATASET_NAME
:
Optional
[
str
]
=
None
DATASET_NAME
:
str
|
None
=
None
OUTPUT_TYPE
:
Optional
[
OutputType
]
=
None
OUTPUT_TYPE
:
OutputType
|
None
=
None
def
__init__
(
def
__init__
(
self
,
self
,
data_dir
:
Optional
[
str
]
=
None
,
data_dir
:
str
|
None
=
None
,
cache_dir
:
Optional
[
str
]
=
None
,
cache_dir
:
str
|
None
=
None
,
download_mode
:
Optional
[
datasets
.
DownloadMode
]
=
None
,
download_mode
:
datasets
.
DownloadMode
|
None
=
None
,
config
:
Optional
[
Mapping
]
=
None
,
# Union[dict, TaskConfig]
config
:
Mapping
|
None
=
None
,
# Union[dict, TaskConfig]
)
->
None
:
)
->
None
:
"""
"""
:param data_dir: str
:param data_dir: str
...
@@ -97,21 +91,21 @@ class Task(abc.ABC):
...
@@ -97,21 +91,21 @@ class Task(abc.ABC):
Fresh download and fresh dataset.
Fresh download and fresh dataset.
"""
"""
self
.
download
(
data_dir
,
cache_dir
,
download_mode
)
self
.
download
(
data_dir
,
cache_dir
,
download_mode
)
self
.
_training_docs
:
Optional
[
list
]
=
None
self
.
_training_docs
:
list
|
None
=
None
self
.
_fewshot_docs
:
Optional
[
list
]
=
None
self
.
_fewshot_docs
:
list
|
None
=
None
self
.
_instances
:
Optional
[
L
ist
[
Instance
]
]
=
None
self
.
_instances
:
l
ist
[
Instance
]
|
None
=
None
self
.
_config
:
TaskConfig
=
TaskConfig
.
from_yaml
({
**
config
})
self
.
_config
:
TaskConfig
=
TaskConfig
.
from_yaml
({
**
config
})
self
.
_filters
=
[
build_filter_ensemble
(
"none"
,
[(
"take_first"
,
None
)])]
self
.
_filters
=
[
build_filter_ensemble
(
"none"
,
[(
"take_first"
,
None
)])]
self
.
fewshot_rnd
:
Optional
[
random
.
Random
]
=
(
self
.
fewshot_rnd
:
random
.
Random
|
None
=
(
None
# purposely induce errors in case of improper usage
None
# purposely induce errors in case of improper usage
)
)
def
download
(
def
download
(
self
,
self
,
data_dir
:
Optional
[
str
]
=
None
,
data_dir
:
str
|
None
=
None
,
cache_dir
:
Optional
[
str
]
=
None
,
cache_dir
:
str
|
None
=
None
,
download_mode
=
None
,
download_mode
=
None
,
)
->
None
:
)
->
None
:
"""Downloads and returns the task dataset.
"""Downloads and returns the task dataset.
...
@@ -238,7 +232,7 @@ class Task(abc.ABC):
...
@@ -238,7 +232,7 @@ class Task(abc.ABC):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
doc_to_target
(
self
,
doc
:
dict
)
->
Union
[
str
,
int
]
:
def
doc_to_target
(
self
,
doc
:
dict
)
->
str
|
int
:
pass
pass
# not an abstractmethod because not every language-only task has to implement this
# not an abstractmethod because not every language-only task has to implement this
...
@@ -254,16 +248,16 @@ class Task(abc.ABC):
...
@@ -254,16 +248,16 @@ class Task(abc.ABC):
def
build_all_requests
(
def
build_all_requests
(
self
,
self
,
*
,
*
,
limit
:
Union
[
int
,
None
]
=
None
,
limit
:
int
|
None
=
None
,
samples
:
Optional
[
L
ist
[
int
]
]
=
None
,
samples
:
l
ist
[
int
]
|
None
=
None
,
rank
:
int
=
0
,
rank
:
int
=
0
,
world_size
:
int
=
1
,
world_size
:
int
=
1
,
cache_requests
:
bool
=
False
,
cache_requests
:
bool
=
False
,
rewrite_requests_cache
:
bool
=
False
,
rewrite_requests_cache
:
bool
=
False
,
system_instruction
:
Optional
[
str
]
=
None
,
system_instruction
:
str
|
None
=
None
,
apply_chat_template
:
bool
=
False
,
apply_chat_template
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
chat_template
:
Optional
[
Callable
]
=
None
,
chat_template
:
Callable
|
None
=
None
,
tokenizer_name
:
str
=
""
,
tokenizer_name
:
str
=
""
,
)
->
None
:
)
->
None
:
"""Build a set of Instances for a task, and store them in task.instances"""
"""Build a set of Instances for a task, and store them in task.instances"""
...
@@ -365,7 +359,7 @@ class Task(abc.ABC):
...
@@ -365,7 +359,7 @@ class Task(abc.ABC):
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
Union
[
list
[
dict
]
,
str
]
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
list
[
dict
]
|
str
,
**
kwargs
):
"""Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
Requests which will be sent to the LM.
...
@@ -405,7 +399,7 @@ class Task(abc.ABC):
...
@@ -405,7 +399,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
functions that aggregate a list of metric scores
"""
"""
pass
return
True
@
deprecated
(
"not used anymore"
)
@
deprecated
(
"not used anymore"
)
def
higher_is_better
(
self
):
def
higher_is_better
(
self
):
...
@@ -414,7 +408,7 @@ class Task(abc.ABC):
...
@@ -414,7 +408,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
whether a higher value of the submetric is better
"""
"""
pass
return
True
def
get_config
(
self
,
key
:
str
)
->
Any
:
def
get_config
(
self
,
key
:
str
)
->
Any
:
return
getattr
(
self
.
_config
,
key
,
None
)
return
getattr
(
self
.
_config
,
key
,
None
)
...
@@ -488,13 +482,15 @@ class Task(abc.ABC):
...
@@ -488,13 +482,15 @@ class Task(abc.ABC):
example
=
self
.
doc_to_text
(
doc
)
example
=
self
.
doc_to_text
(
doc
)
return
description
+
labeled_examples
+
example
return
description
+
labeled_examples
+
example
def
apply_filters
(
self
)
->
Optional
[
L
ist
[
Instance
]
]
:
def
apply_filters
(
self
)
->
l
ist
[
Instance
]
|
None
:
"""Iterates over FilterEnsembles and applies them to instances"""
"""Iterates over FilterEnsembles and applies them to instances"""
if
hasattr
(
self
,
"_filters"
):
if
hasattr
(
self
,
"_filters"
)
and
self
.
_instances
:
for
f
in
self
.
_filters
:
for
f
in
self
.
_filters
:
f
.
apply
(
self
.
_instances
)
f
.
apply
(
self
.
_instances
)
else
:
else
:
eval_logger
.
warning
(
"No filter defined, passing through instances"
)
eval_logger
.
warning
(
"No filter defined or no instances, passing through instances"
)
return
self
.
_instances
return
self
.
_instances
def
dump_config
(
self
)
->
dict
:
def
dump_config
(
self
)
->
dict
:
...
@@ -505,9 +501,6 @@ class Task(abc.ABC):
...
@@ -505,9 +501,6 @@ class Task(abc.ABC):
def
set_config
(
self
,
key
:
str
,
value
:
Any
,
update
:
bool
=
False
)
->
None
:
def
set_config
(
self
,
key
:
str
,
value
:
Any
,
update
:
bool
=
False
)
->
None
:
"""Set or update the configuration for a given key."""
"""Set or update the configuration for a given key."""
if
key
is
None
:
raise
ValueError
(
"Key must be provided."
)
if
update
:
if
update
:
current_value
=
getattr
(
self
.
_config
,
key
,
{})
current_value
=
getattr
(
self
.
_config
,
key
,
{})
if
not
isinstance
(
current_value
,
dict
):
if
not
isinstance
(
current_value
,
dict
):
...
@@ -533,13 +526,13 @@ class Task(abc.ABC):
...
@@ -533,13 +526,13 @@ class Task(abc.ABC):
setattr
(
self
.
_config
,
"metric_list"
,
[
MetricConfig
(
name
=
metric_name
)])
setattr
(
self
.
_config
,
"metric_list"
,
[
MetricConfig
(
name
=
metric_name
)])
setattr
(
self
.
_config
,
"process_results"
,
lambda
*
args
:
{
"bypass"
:
0
})
setattr
(
self
.
_config
,
"process_results"
,
lambda
*
args
:
{
"bypass"
:
0
})
def
set_fewshot_seed
(
self
,
seed
:
Optional
[
int
]
=
None
)
->
None
:
def
set_fewshot_seed
(
self
,
seed
:
int
|
None
=
None
)
->
None
:
self
.
fewshot_rnd
=
random
.
Random
(
seed
)
self
.
fewshot_rnd
=
random
.
Random
(
seed
)
if
hasattr
(
self
,
"sampler"
):
if
hasattr
(
self
,
"sampler"
):
self
.
sampler
.
rnd
=
self
.
fewshot_rnd
self
.
sampler
.
rnd
=
self
.
fewshot_rnd
@
property
@
property
def
eval_docs
(
self
)
->
Union
[
datasets
.
Dataset
,
Iterable
[
dict
]
]
:
def
eval_docs
(
self
)
->
datasets
.
Dataset
|
Iterable
[
dict
]:
if
self
.
has_test_docs
():
if
self
.
has_test_docs
():
return
self
.
test_docs
()
return
self
.
test_docs
()
elif
self
.
has_validation_docs
():
elif
self
.
has_validation_docs
():
...
@@ -553,13 +546,13 @@ class Task(abc.ABC):
...
@@ -553,13 +546,13 @@ class Task(abc.ABC):
self
,
self
,
*
,
*
,
rank
:
int
=
0
,
rank
:
int
=
0
,
limit
:
Union
[
int
,
None
]
=
None
,
limit
:
int
|
None
=
None
,
world_size
:
int
=
1
,
world_size
:
int
=
1
,
samples
:
Optional
[
L
ist
[
int
]
]
=
None
,
samples
:
l
ist
[
int
]
|
None
=
None
,
)
->
Iterator
[
T
uple
[
int
,
Any
]]:
)
->
Iterator
[
t
uple
[
int
,
Any
]]:
if
samples
:
if
samples
:
n
=
len
(
self
.
eval_docs
)
n
=
len
(
self
.
eval_docs
)
assert
all
(
[
e
<
n
for
e
in
samples
]
),
(
assert
all
(
e
<
n
for
e
in
samples
),
(
f
"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k=
{
n
}
."
f
"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k=
{
n
}
."
)
)
eval_logger
.
info
(
eval_logger
.
info
(
...
@@ -592,7 +585,7 @@ class ConfigurableTask(Task):
...
@@ -592,7 +585,7 @@ class ConfigurableTask(Task):
data_dir
=
None
,
data_dir
=
None
,
cache_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
download_mode
=
None
,
config
:
Optional
[
dict
]
=
None
,
config
:
dict
|
None
=
None
,
)
->
None
:
)
->
None
:
# Get pre-configured attributes
# Get pre-configured attributes
self
.
_config
=
self
.
CONFIG
self
.
_config
=
self
.
CONFIG
...
@@ -610,8 +603,7 @@ class ConfigurableTask(Task):
...
@@ -610,8 +603,7 @@ class ConfigurableTask(Task):
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
)
if
isinstance
(
self
.
config
.
metadata
,
dict
):
if
isinstance
(
self
.
config
.
metadata
,
dict
)
and
"version"
in
self
.
config
.
metadata
:
if
"version"
in
self
.
config
.
metadata
:
self
.
VERSION
=
self
.
config
.
metadata
[
"version"
]
self
.
VERSION
=
self
.
config
.
metadata
[
"version"
]
if
self
.
config
.
output_type
is
not
None
:
if
self
.
config
.
output_type
is
not
None
:
...
@@ -698,18 +690,13 @@ class ConfigurableTask(Task):
...
@@ -698,18 +690,13 @@ class ConfigurableTask(Task):
else
:
else
:
test_target
=
str
(
test_target
)
test_target
=
str
(
test_target
)
if
test_choice
is
not
None
:
check_choices
=
test_choice
if
test_choice
is
not
None
else
[
test_target
]
check_choices
=
test_choice
else
:
check_choices
=
[
test_target
]
if
self
.
config
.
doc_to_choice
is
not
None
:
if
self
.
config
.
doc_to_choice
is
not
None
:
for
choice
in
check_choices
:
for
choice
in
check_choices
:
choice_has_whitespace
=
True
if
choice
[
0
].
isspace
()
else
False
choice_has_whitespace
=
choice
[
0
].
isspace
()
delimiter_has_whitespace
=
(
delimiter_has_whitespace
=
(
True
self
.
config
.
target_delimiter
.
rstrip
()
if
self
.
config
.
target_delimiter
.
rstrip
()
!=
self
.
config
.
target_delimiter
!=
self
.
config
.
target_delimiter
else
False
)
)
if
delimiter_has_whitespace
and
choice_has_whitespace
:
if
delimiter_has_whitespace
and
choice_has_whitespace
:
...
@@ -722,7 +709,7 @@ class ConfigurableTask(Task):
...
@@ -722,7 +709,7 @@ class ConfigurableTask(Task):
)
)
def
download
(
def
download
(
self
,
dataset_kwargs
:
Optional
[
D
ict
[
str
,
Any
]
]
=
None
,
**
kwargs
self
,
dataset_kwargs
:
d
ict
[
str
,
Any
]
|
None
=
None
,
**
kwargs
)
->
None
:
)
->
None
:
from
packaging.version
import
parse
as
vparse
from
packaging.version
import
parse
as
vparse
...
@@ -746,24 +733,15 @@ class ConfigurableTask(Task):
...
@@ -746,24 +733,15 @@ class ConfigurableTask(Task):
)
)
def
has_training_docs
(
self
)
->
bool
:
def
has_training_docs
(
self
)
->
bool
:
if
self
.
config
.
training_split
is
not
None
:
return
self
.
config
.
training_split
is
not
None
return
True
else
:
return
False
def
has_validation_docs
(
self
)
->
bool
:
def
has_validation_docs
(
self
)
->
bool
:
if
self
.
config
.
validation_split
is
not
None
:
return
self
.
config
.
validation_split
is
not
None
return
True
else
:
return
False
def
has_test_docs
(
self
)
->
bool
:
def
has_test_docs
(
self
)
->
bool
:
if
self
.
config
.
test_split
is
not
None
:
return
self
.
config
.
test_split
is
not
None
return
True
else
:
return
False
def
training_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
def
training_docs
(
self
)
->
datasets
.
Dataset
|
None
:
if
self
.
has_training_docs
():
if
self
.
has_training_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -771,7 +749,7 @@ class ConfigurableTask(Task):
...
@@ -771,7 +749,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
training_split
]
return
self
.
dataset
[
self
.
config
.
training_split
]
def
validation_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
def
validation_docs
(
self
)
->
datasets
.
Dataset
|
None
:
if
self
.
has_validation_docs
():
if
self
.
has_validation_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -779,7 +757,7 @@ class ConfigurableTask(Task):
...
@@ -779,7 +757,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
validation_split
]
return
self
.
dataset
[
self
.
config
.
validation_split
]
def
test_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
def
test_docs
(
self
)
->
datasets
.
Dataset
|
None
:
if
self
.
has_test_docs
():
if
self
.
has_test_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
...
@@ -792,8 +770,11 @@ class ConfigurableTask(Task):
...
@@ -792,8 +770,11 @@ class ConfigurableTask(Task):
return
docs
return
docs
# Fallback to parent implementation
# Fallback to parent implementation
if
_num_fewshot
:
=
getattr
(
self
.
config
,
"num_fewshot"
):
if
(
if
isinstance
(
_num_fewshot
,
int
)
and
_num_fewshot
>
0
:
(
_num_fewshot
:
=
self
.
config
.
num_fewshot
)
and
isinstance
(
_num_fewshot
,
int
)
and
_num_fewshot
>
0
):
eval_logger
.
warning
(
eval_logger
.
warning
(
f
"[Task:
{
self
.
config
.
task
}
] "
f
"[Task:
{
self
.
config
.
task
}
] "
"num_fewshot > 0 but no fewshot source configured. "
"num_fewshot > 0 but no fewshot source configured. "
...
@@ -804,10 +785,10 @@ class ConfigurableTask(Task):
...
@@ -804,10 +785,10 @@ class ConfigurableTask(Task):
@
staticmethod
@
staticmethod
def
append_target_question
(
def
append_target_question
(
labeled_examples
:
L
ist
[
D
ict
[
str
,
str
]],
labeled_examples
:
l
ist
[
d
ict
[
str
,
str
]],
question
:
str
,
question
:
str
,
fewshot_as_multiturn
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
gen_prefix
:
Optional
[
str
]
=
None
,
gen_prefix
:
str
|
None
=
None
,
)
->
None
:
)
->
None
:
"""Adds a target question to the labeled examples list.
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
...
@@ -831,12 +812,12 @@ class ConfigurableTask(Task):
...
@@ -831,12 +812,12 @@ class ConfigurableTask(Task):
self
,
self
,
doc
:
dict
,
doc
:
dict
,
num_fewshot
:
int
,
num_fewshot
:
int
,
system_instruction
:
Optional
[
str
]
=
None
,
system_instruction
:
str
|
None
=
None
,
apply_chat_template
:
bool
=
False
,
apply_chat_template
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
chat_template
:
Optional
[
Callable
]
=
None
,
chat_template
:
Callable
|
None
=
None
,
gen_prefix
:
Optional
[
str
]
=
None
,
gen_prefix
:
str
|
None
=
None
,
)
->
Union
[
str
,
L
ist
[
str
]
,
None
]
:
)
->
str
|
l
ist
[
str
]
|
None
:
"""Returns a fewshot context string that is made up of a prepended description
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
...
@@ -857,10 +838,7 @@ class ConfigurableTask(Task):
...
@@ -857,10 +838,7 @@ class ConfigurableTask(Task):
:returns: str
:returns: str
The fewshot context.
The fewshot context.
"""
"""
if
apply_chat_template
:
labeled_examples
=
[]
if
apply_chat_template
else
""
labeled_examples
=
[]
else
:
labeled_examples
=
""
# get task description
# get task description
if
description
:
=
self
.
config
.
description
:
if
description
:
=
self
.
config
.
description
:
...
@@ -930,7 +908,7 @@ class ConfigurableTask(Task):
...
@@ -930,7 +908,7 @@ class ConfigurableTask(Task):
labeled_examples_list
.
append
(
labeled_examples_list
.
append
(
chat_template
(
chat_template
(
chat
,
chat
,
add_generation_prompt
=
False
if
gen_prefix
else
True
,
add_generation_prompt
=
not
gen_prefix
,
)
)
)
)
return
labeled_examples_list
return
labeled_examples_list
...
@@ -954,7 +932,7 @@ class ConfigurableTask(Task):
...
@@ -954,7 +932,7 @@ class ConfigurableTask(Task):
# return lm.apply_chat_template(labeled_examples)
# return lm.apply_chat_template(labeled_examples)
return
chat_template
(
return
chat_template
(
labeled_examples
,
labeled_examples
,
add_generation_prompt
=
False
if
gen_prefix
else
True
,
add_generation_prompt
=
not
gen_prefix
,
)
)
else
:
else
:
prefix
=
(
prefix
=
(
...
@@ -975,7 +953,7 @@ class ConfigurableTask(Task):
...
@@ -975,7 +953,7 @@ class ConfigurableTask(Task):
else
:
else
:
return
labeled_examples
+
str
(
example
)
+
prefix
return
labeled_examples
+
str
(
example
)
+
prefix
def
apply_filters
(
self
)
->
Optional
[
L
ist
[
Instance
]
]
:
def
apply_filters
(
self
)
->
l
ist
[
Instance
]
|
None
:
"""Iterates over FilterEnsembles and applies them to instances"""
"""Iterates over FilterEnsembles and applies them to instances"""
if
hasattr
(
self
,
"_filters"
):
if
hasattr
(
self
,
"_filters"
):
for
f
in
self
.
_filters
:
for
f
in
self
.
_filters
:
...
@@ -1015,9 +993,7 @@ class ConfigurableTask(Task):
...
@@ -1015,9 +993,7 @@ class ConfigurableTask(Task):
"""
"""
return
doc
return
doc
def
doc_to_text
(
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
|
None
=
None
):
self
,
doc
:
dict
,
doc_to_text
:
Union
[
int
,
str
,
Callable
,
None
]
=
None
):
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_text = self.prompt
# doc_to_text = self.prompt
if
doc_to_text
is
not
None
:
if
doc_to_text
is
not
None
:
...
@@ -1053,9 +1029,7 @@ class ConfigurableTask(Task):
...
@@ -1053,9 +1029,7 @@ class ConfigurableTask(Task):
print
(
type
(
doc_to_text
))
print
(
type
(
doc_to_text
))
raise
TypeError
raise
TypeError
def
doc_to_target
(
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
Union
[
int
,
str
,
list
[
int
]]:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_target = self.prompt
# doc_to_target = self.prompt
if
doc_to_target
is
not
None
:
if
doc_to_target
is
not
None
:
...
@@ -1104,8 +1078,8 @@ class ConfigurableTask(Task):
...
@@ -1104,8 +1078,8 @@ class ConfigurableTask(Task):
def
doc_to_choice
(
def
doc_to_choice
(
self
,
self
,
doc
:
dict
,
doc
:
dict
,
doc_to_choice
:
Union
[
str
,
list
,
dict
,
Callable
[...,
list
[
str
]]
,
None
]
=
None
,
doc_to_choice
:
str
|
list
|
dict
|
Callable
[...,
list
[
str
]]
|
None
=
None
,
)
->
L
ist
[
str
]:
)
->
l
ist
[
str
]:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_choice = self.prompt
# doc_to_choice = self.prompt
if
doc_to_choice
is
not
None
:
if
doc_to_choice
is
not
None
:
...
@@ -1132,7 +1106,7 @@ class ConfigurableTask(Task):
...
@@ -1132,7 +1106,7 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
,
None
]
:
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_image
is
not
None
:
if
doc_to_image
is
not
None
:
doc_to_image
=
doc_to_image
doc_to_image
=
doc_to_image
elif
self
.
config
.
doc_to_image
is
not
None
:
elif
self
.
config
.
doc_to_image
is
not
None
:
...
@@ -1155,7 +1129,7 @@ class ConfigurableTask(Task):
...
@@ -1155,7 +1129,7 @@ class ConfigurableTask(Task):
else
:
else
:
return
None
return
None
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
Union
[
int
,
str
,
list
,
None
]
:
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_audio
is
not
None
:
if
doc_to_audio
is
not
None
:
doc_to_audio
=
doc_to_audio
doc_to_audio
=
doc_to_audio
elif
self
.
config
.
doc_to_audio
is
not
None
:
elif
self
.
config
.
doc_to_audio
is
not
None
:
...
@@ -1178,7 +1152,7 @@ class ConfigurableTask(Task):
...
@@ -1178,7 +1152,7 @@ class ConfigurableTask(Task):
else
:
else
:
return
None
return
None
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
Optional
[
str
]
:
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
str
|
None
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
gen_prefix
in
self
.
features
:
if
gen_prefix
in
self
.
features
:
return
doc
[
gen_prefix
]
return
doc
[
gen_prefix
]
...
@@ -1188,7 +1162,7 @@ class ConfigurableTask(Task):
...
@@ -1188,7 +1162,7 @@ class ConfigurableTask(Task):
def
construct_requests
(
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
Union
[
L
ist
[
Instance
]
,
Instance
]
:
)
->
l
ist
[
Instance
]
|
Instance
:
apply_chat_template
=
kwargs
.
pop
(
"apply_chat_template"
,
False
)
apply_chat_template
=
kwargs
.
pop
(
"apply_chat_template"
,
False
)
chat_template
:
Callable
|
None
=
kwargs
.
pop
(
"chat_template"
,
None
)
chat_template
:
Callable
|
None
=
kwargs
.
pop
(
"chat_template"
,
None
)
...
@@ -1324,7 +1298,7 @@ class ConfigurableTask(Task):
...
@@ -1324,7 +1298,7 @@ class ConfigurableTask(Task):
elif
self
.
OUTPUT_TYPE
==
"multiple_choice"
:
elif
self
.
OUTPUT_TYPE
==
"multiple_choice"
:
lls
,
is_greedy
=
zip
(
*
results
)
lls
,
is_greedy
=
zip
(
*
results
)
# retrieve choices in
L
ist[str] form, to compute choice lengths, etc.
# retrieve choices in
l
ist[str] form, to compute choice lengths, etc.
choices
=
self
.
doc_to_choice
(
doc
)
choices
=
self
.
doc_to_choice
(
doc
)
completion_len
=
np
.
array
([
float
(
len
(
i
))
for
i
in
choices
])
completion_len
=
np
.
array
([
float
(
len
(
i
))
for
i
in
choices
])
...
@@ -1371,7 +1345,7 @@ class ConfigurableTask(Task):
...
@@ -1371,7 +1345,7 @@ class ConfigurableTask(Task):
if
self
.
multiple_target
:
if
self
.
multiple_target
:
acc
=
1.0
if
pred
in
gold
else
0.0
acc
=
1.0
if
pred
in
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
in
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
in
gold
else
0.0
exact_match
=
int
(
any
(
[
is_greedy
[
i
]
if
i
!=
-
100
else
0
for
i
in
gold
]
))
exact_match
=
int
(
any
(
is_greedy
[
i
]
if
i
!=
-
100
else
0
for
i
in
gold
))
else
:
else
:
acc
=
1.0
if
pred
==
gold
else
0.0
acc
=
1.0
if
pred
==
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
==
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
==
gold
else
0.0
...
@@ -1413,7 +1387,7 @@ class ConfigurableTask(Task):
...
@@ -1413,7 +1387,7 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number.
# it assumes that doc_to_target returns a number.
choices
=
self
.
doc_to_choice
(
doc
)
choices
=
self
.
doc_to_choice
(
doc
)
gold
=
choices
[
gold
]
gold
=
choices
[
gold
]
for
metric
in
self
.
_metric_fn_list
.
keys
()
:
for
metric
in
self
.
_metric_fn_list
:
try
:
try
:
result_score
=
self
.
_metric_fn_list
[
metric
](
result_score
=
self
.
_metric_fn_list
[
metric
](
references
=
[
gold
]
if
not
isinstance
(
gold
,
list
)
else
gold
,
references
=
[
gold
]
if
not
isinstance
(
gold
,
list
)
else
gold
,
...
@@ -1447,7 +1421,7 @@ class ConfigurableTask(Task):
...
@@ -1447,7 +1421,7 @@ class ConfigurableTask(Task):
return
getattr
(
self
.
_config
,
key
,
None
)
return
getattr
(
self
.
_config
,
key
,
None
)
@
property
@
property
def
task_name
(
self
)
->
Optional
[
str
]
:
def
task_name
(
self
)
->
str
|
None
:
return
getattr
(
self
.
config
,
"task"
,
None
)
return
getattr
(
self
.
config
,
"task"
,
None
)
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -1465,7 +1439,7 @@ class MultipleChoiceTask(Task):
...
@@ -1465,7 +1439,7 @@ class MultipleChoiceTask(Task):
def
doc_to_target
(
self
,
doc
:
dict
)
->
str
:
def
doc_to_target
(
self
,
doc
:
dict
)
->
str
:
return
" "
+
doc
[
"choices"
][
doc
[
"gold"
]]
return
" "
+
doc
[
"choices"
][
doc
[
"gold"
]]
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
L
ist
[
Instance
]:
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
l
ist
[
Instance
]:
# TODO: add mutual info here?
# TODO: add mutual info here?
return
[
return
[
Instance
(
Instance
(
...
@@ -1478,7 +1452,7 @@ class MultipleChoiceTask(Task):
...
@@ -1478,7 +1452,7 @@ class MultipleChoiceTask(Task):
for
i
,
choice
in
enumerate
(
doc
[
"choices"
])
for
i
,
choice
in
enumerate
(
doc
[
"choices"
])
]
]
def
process_results
(
self
,
doc
:
dict
,
results
:
Iterable
[
T
uple
[
float
,
bool
]])
->
dict
:
def
process_results
(
self
,
doc
:
dict
,
results
:
Iterable
[
t
uple
[
float
,
bool
]])
->
dict
:
results
=
[
results
=
[
res
[
0
]
for
res
in
results
res
[
0
]
for
res
in
results
]
# only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
]
# only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
...
@@ -1512,7 +1486,7 @@ class PerplexityTask(Task):
...
@@ -1512,7 +1486,7 @@ class PerplexityTask(Task):
def
has_training_docs
(
self
)
->
bool
:
def
has_training_docs
(
self
)
->
bool
:
return
False
return
False
def
fewshot_examples
(
self
,
k
:
int
,
rnd
)
->
L
ist
:
def
fewshot_examples
(
self
,
k
:
int
,
rnd
)
->
l
ist
:
if
k
!=
0
:
if
k
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The number of fewshot examples must be 0 for perplexity tasks."
"The number of fewshot examples must be 0 for perplexity tasks."
...
@@ -1543,7 +1517,7 @@ class PerplexityTask(Task):
...
@@ -1543,7 +1517,7 @@ class PerplexityTask(Task):
def
doc_to_target
(
self
,
doc
):
def
doc_to_target
(
self
,
doc
):
return
doc
return
doc
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
Optional
[
str
]
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
|
None
,
**
kwargs
):
if
bool
(
ctx
):
if
bool
(
ctx
):
raise
ValueError
raise
ValueError
...
@@ -1555,7 +1529,7 @@ class PerplexityTask(Task):
...
@@ -1555,7 +1529,7 @@ class PerplexityTask(Task):
**
kwargs
,
**
kwargs
,
)
)
def
process_results
(
self
,
doc
:
dict
,
results
:
T
uple
[
float
])
->
dict
:
def
process_results
(
self
,
doc
:
dict
,
results
:
t
uple
[
float
])
->
dict
:
(
loglikelihood
,)
=
results
(
loglikelihood
,)
=
results
words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
bytes_
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
bytes_
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
...
...
lm_eval/config/metric.py
View file @
69d14fb3
from
__future__
import
annotations
from
collections.abc
import
Callable
,
Mapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
List
,
Optional
from
typing
import
Any
@
dataclass
@
dataclass
...
@@ -8,9 +11,9 @@ class MetricConfig:
...
@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
"""Encapsulates information about a single metric."""
name
:
str
name
:
str
fn
:
Optional
[
Callable
]
=
None
fn
:
Callable
|
None
=
None
kwargs
:
Optional
[
dict
]
=
None
kwargs
:
Mapping
[
str
,
Any
]
|
None
=
None
aggregation_fn
:
Optional
[
Callable
]
=
None
aggregation_fn
:
Callable
|
None
=
None
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
hf_evaluate
:
bool
=
False
hf_evaluate
:
bool
=
False
is_elementwise
:
bool
=
True
is_elementwise
:
bool
=
True
...
@@ -20,7 +23,7 @@ class MetricConfig:
...
@@ -20,7 +23,7 @@ class MetricConfig:
return
self
.
name
return
self
.
name
@
cached_property
@
cached_property
def
aggregation
(
self
)
->
Callable
:
def
aggregation
(
self
)
->
Callable
[...,
Any
]
|
None
:
from
lm_eval.api.registry
import
get_aggregation
from
lm_eval.api.registry
import
get_aggregation
if
self
.
aggregation_fn
is
None
:
if
self
.
aggregation_fn
is
None
:
...
@@ -28,7 +31,7 @@ class MetricConfig:
...
@@ -28,7 +31,7 @@ class MetricConfig:
return
self
.
aggregation_fn
return
self
.
aggregation_fn
@
cached_property
@
cached_property
def
_higher_is_better
(
self
)
->
bool
:
def
_higher_is_better
(
self
)
->
bool
|
None
:
from
lm_eval.api.registry
import
is_higher_better
from
lm_eval.api.registry
import
is_higher_better
if
self
.
higher_is_better
is
None
:
if
self
.
higher_is_better
is
None
:
...
@@ -39,10 +42,10 @@ class MetricConfig:
...
@@ -39,10 +42,10 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments."""
"""Calculates the metric using the provided function and arguments."""
if
self
.
fn
is
None
:
if
self
.
fn
is
None
:
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
return
self
.
fn
(
*
args
,
**
{
**
self
.
kwargs
,
**
kwargs
})
return
self
.
fn
(
*
args
,
**
{
**
(
self
.
kwargs
or
{})
,
**
kwargs
})
def
compute_aggregation
(
self
,
values
:
List
[
Any
]
)
->
Any
:
def
compute_aggregation
(
self
,
*
args
,
**
kwargs
)
->
Any
:
"""Computes the aggregation of the metric values."""
"""Computes the aggregation of the metric values."""
if
self
.
aggregation_fn
is
None
:
if
self
.
aggregation_fn
is
None
:
raise
ValueError
(
f
"Aggregation function for
{
self
.
name
}
is not defined."
)
raise
ValueError
(
f
"Aggregation function for
{
self
.
name
}
is not defined."
)
return
self
.
aggregation_fn
(
value
s
)
return
self
.
aggregation_fn
(
*
args
,
**
kwarg
s
)
lm_eval/config/task.py
View file @
69d14fb3
from
__future__
import
annotations
import
logging
import
logging
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
,
field
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Iterable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
from
lm_eval.api.instance
import
OutputType
...
@@ -20,8 +23,8 @@ class RepeatConfig:
...
@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
"""Encapsulates information about a single repeat."""
repeats
:
int
=
1
repeats
:
int
=
1
metric_fn
:
Union
[
str
,
Callable
]
=
"pass@N"
metric_fn
:
str
|
Callable
=
"pass@N"
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
@
dataclass
@
dataclass
...
@@ -38,11 +41,11 @@ class FewshotConfig:
...
@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
# to keep in sync as it is runtime-modified
num_fewshot
:
Callable
[[],
int
]
num_fewshot
:
Callable
[[],
int
]
split
:
Optional
[
str
]
=
None
split
:
str
|
None
=
None
sampler
:
Union
[
str
,
Callable
]
=
"default"
sampler
:
str
|
Callable
=
"default"
samples
:
Union
[
Callable
[[],
list
[
dict
]]
,
list
[
dict
]
,
None
]
=
None
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
process_docs
:
Optional
[
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
]
=
None
process_docs
:
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
|
None
=
None
fewshot_indices
:
Optional
[
list
[
int
]
]
=
None
fewshot_indices
:
list
[
int
]
|
None
=
None
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
...
@@ -65,22 +68,20 @@ class FewshotConfig:
...
@@ -65,22 +68,20 @@ class FewshotConfig:
def
_get_raw_docs
(
def
_get_raw_docs
(
self
,
dataset
self
,
dataset
)
->
Union
[
list
[
dict
]
,
Callable
[[],
Iterable
[
dict
]]
,
None
]
:
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
]]
|
None
:
"""Get raw documents from configured source."""
"""Get raw documents from configured source."""
if
self
.
split
is
not
None
:
if
self
.
split
is
not
None
:
return
dataset
[
self
.
split
]
return
dataset
[
self
.
split
]
if
self
.
samples
is
not
None
:
if
self
.
samples
is
not
None
:
if
isinstance
(
self
.
samples
,
list
):
if
isinstance
(
self
.
samples
,
list
)
or
callable
(
self
.
samples
):
return
self
.
samples
elif
callable
(
self
.
samples
):
return
self
.
samples
return
self
.
samples
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"samples must be either a list of dicts or a callable returning a list"
"samples must be either a list of dicts or a callable returning a list"
)
)
def
get_docs
(
self
,
dataset
)
->
Optional
[
Iterable
[
dict
]
]
:
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
]
|
None
:
"""Get processed documents from configured source."""
"""Get processed documents from configured source."""
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
if
raw_docs
is
None
:
if
raw_docs
is
None
:
...
@@ -100,8 +101,8 @@ class FewshotConfig:
...
@@ -100,8 +101,8 @@ class FewshotConfig:
return
self
.
sampler
return
self
.
sampler
def
init_sampler
(
def
init_sampler
(
self
,
docs
:
list
[
dict
],
task
:
"
Task
"
,
rnd
=
None
,
fewshot_indices
=
None
self
,
docs
:
list
[
dict
],
task
:
Task
,
rnd
=
None
,
fewshot_indices
=
None
)
->
"
ContextSampler
"
:
)
->
ContextSampler
:
"""Initialize the sampler with the given documents and task."""
"""Initialize the sampler with the given documents and task."""
if
rnd
is
None
:
if
rnd
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -120,49 +121,49 @@ class FewshotConfig:
...
@@ -120,49 +121,49 @@ class FewshotConfig:
@
dataclass
@
dataclass
class
TaskConfig
(
dict
):
class
TaskConfig
(
dict
):
# task naming/registry
# task naming/registry
task
:
Optional
[
str
]
=
None
task
:
str
|
None
=
None
task_alias
:
Optional
[
str
]
=
None
task_alias
:
str
|
None
=
None
tag
:
Optional
[
Union
[
str
,
list
]]
=
None
tag
:
str
|
list
|
None
=
None
# HF dataset options.
# HF dataset options.
# which dataset to use,
# which dataset to use,
# and what splits for what purpose
# and what splits for what purpose
custom_dataset
:
Optional
[
Callable
]
=
None
custom_dataset
:
Callable
|
None
=
None
dataset_path
:
Optional
[
str
]
=
None
dataset_path
:
str
|
None
=
None
dataset_name
:
Optional
[
str
]
=
None
dataset_name
:
str
|
None
=
None
dataset_kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
dataset_kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
training_split
:
Optional
[
str
]
=
None
training_split
:
str
|
None
=
None
validation_split
:
Optional
[
str
]
=
None
validation_split
:
str
|
None
=
None
test_split
:
Optional
[
str
]
=
None
test_split
:
str
|
None
=
None
fewshot_split
:
Optional
[
str
]
=
(
fewshot_split
:
str
|
None
=
(
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
)
# formatting / prompting options.
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
# see docs/advanced_task_guide.md for more info
process_docs
:
Optional
[
Callable
]
=
None
process_docs
:
Callable
|
None
=
None
doc_to_text
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_text
:
Callable
|
str
|
None
=
None
doc_to_target
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_target
:
Callable
|
str
|
None
=
None
doc_to_image
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_image
:
Callable
|
str
|
None
=
None
doc_to_audio
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_audio
:
Callable
|
str
|
None
=
None
unsafe_code
:
bool
=
False
unsafe_code
:
bool
=
False
doc_to_choice
:
Optional
[
Union
[
Callable
,
str
,
dict
,
list
]]
=
None
doc_to_choice
:
Callable
|
str
|
dict
|
list
|
None
=
None
process_results
:
Optional
[
Union
[
Callable
,
str
]]
=
None
process_results
:
Callable
|
str
|
None
=
None
use_prompt
:
Optional
[
str
]
=
None
use_prompt
:
str
|
None
=
None
description
:
str
=
""
description
:
str
=
""
target_delimiter
:
str
=
" "
target_delimiter
:
str
=
" "
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_config
:
Optional
[
dict
]
=
None
fewshot_config
:
dict
|
None
=
None
# runtime configuration options
# runtime configuration options
num_fewshot
:
Optional
[
int
]
=
0
num_fewshot
:
int
|
None
=
0
generation_kwargs
:
Optional
[
dict
]
=
None
generation_kwargs
:
dict
|
None
=
None
# scoring options
# scoring options
metric_list
:
Optional
[
list
]
=
None
metric_list
:
list
|
None
=
None
output_type
:
OutputType
=
"generate_until"
output_type
:
OutputType
=
"generate_until"
repeats
:
int
=
1
repeats
:
int
=
1
filter_list
:
Optional
[
list
[
dict
]
]
=
None
filter_list
:
list
[
dict
]
|
None
=
None
should_decontaminate
:
bool
=
False
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
doc_to_decontamination_query
:
str
|
None
=
None
gen_prefix
:
Optional
[
str
]
=
None
gen_prefix
:
str
|
None
=
None
metadata
:
Optional
[
dict
]
=
field
(
metadata
:
dict
|
None
=
field
(
default_factory
=
dict
default_factory
=
dict
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
...
@@ -215,9 +216,7 @@ class TaskConfig(dict):
...
@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
)
)
def
_get_metric
(
def
_get_metric
(
self
,
metric_list
:
list
[
dict
]
|
None
=
None
)
->
list
[
MetricConfig
]:
self
,
metric_list
:
Optional
[
list
[
dict
]]
=
None
)
->
list
[
"MetricConfig"
]:
from
lm_eval.api.registry
import
(
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
...
@@ -314,7 +313,7 @@ class TaskConfig(dict):
...
@@ -314,7 +313,7 @@ class TaskConfig(dict):
return
metrics
return
metrics
@
property
@
property
def
get_filters
(
self
)
->
list
[
"
FilterConfig
"
]:
def
get_filters
(
self
)
->
list
[
FilterConfig
]:
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
if
not
self
.
filter_list
:
if
not
self
.
filter_list
:
...
@@ -354,7 +353,7 @@ class TaskConfig(dict):
...
@@ -354,7 +353,7 @@ class TaskConfig(dict):
return
x
return
x
@
classmethod
@
classmethod
def
from_yaml
(
cls
,
data
:
dict
)
->
"
TaskConfig
"
:
def
from_yaml
(
cls
,
data
:
dict
)
->
TaskConfig
:
"""Create a TaskConfig instance from a YAML-like dictionary."""
"""Create a TaskConfig instance from a YAML-like dictionary."""
return
cls
(
**
data
)
return
cls
(
**
data
)
...
...
lm_eval/config/template.py
View file @
69d14fb3
from
__future__
import
annotations
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -11,19 +13,19 @@ class TemplateConfig:
...
@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template."""
"""Encapsulates information about a template."""
template
:
str
template
:
str
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
description
:
str
description
:
str
context_prefix
:
str
context_prefix
:
str
prefix_delimiter
:
str
prefix_delimiter
:
str
context_delimiter
:
str
context_delimiter
:
str
answer_suffix
:
str
answer_suffix
:
str
target_delimiter
:
str
target_delimiter
:
str
choice_format
:
Optional
[
str
]
choice_format
:
str
|
None
choice_delimiter
:
Optional
[
str
]
choice_delimiter
:
str
|
None
fewshot_delimiter
:
str
fewshot_delimiter
:
str
metric_list
:
Optional
[
Union
[
list
[
str
]
,
list
[
"
MetricConfig
"
]]]
=
field
(
metric_list
:
list
[
str
]
|
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
)
...
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
...
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice.
Answer:` doc_to_choice(doc)` for each choice.
"""
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
=
"mcq"
template
=
"mcq"
context_prefix
:
str
=
"Question:"
context_prefix
:
str
=
"Question:"
prefix_delimiter
:
str
=
" "
prefix_delimiter
:
str
=
" "
context_delimiter
:
str
=
"
\n
"
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
"
\n
"
target_delimiter
:
str
=
"
\n
"
choice_format
:
Optional
[
str
]
=
"letters"
choice_format
:
str
|
None
=
"letters"
choice_delimiter
:
Optional
[
str
]
=
"
\n
"
choice_delimiter
:
str
|
None
=
"
\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
default_factory
=
lambda
:
[
"acc"
])
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
])
@
dataclass
@
dataclass
...
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
...
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>`
Answer:` <doc_to_target(doc)>`
"""
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
:
str
=
"cloze"
template
:
str
=
"cloze"
description
:
str
=
""
description
:
str
=
""
context_prefix
:
str
=
"Question:"
context_prefix
:
str
=
"Question:"
...
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
...
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter
:
str
=
"
\n
"
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
" "
target_delimiter
:
str
=
" "
choice_format
:
Optional
[
str
]
=
None
choice_format
:
str
|
None
=
None
choice_delimiter
:
Optional
[
str
]
=
None
choice_delimiter
:
str
|
None
=
None
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
)
lm_eval/config/utils.py
View file @
69d14fb3
from
__future__
import
annotations
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
Any
,
Callable
,
Union
from
typing
import
Any
,
Callable
def
serialize_callable
(
def
serialize_callable
(
value
:
Union
[
Callable
[...,
Any
]
,
str
]
,
keep_callable
=
False
value
:
Callable
[...,
Any
]
|
str
,
keep_callable
=
False
)
->
Union
[
Callable
[...,
Any
]
,
str
]
:
)
->
Callable
[...,
Any
]
|
str
:
"""Serializes a given function or string.
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
If 'keep_callable' is True, the original callable is returned.
...
@@ -20,9 +22,7 @@ def serialize_callable(
...
@@ -20,9 +22,7 @@ def serialize_callable(
return
str
(
value
)
return
str
(
value
)
def
maybe_serialize
(
def
maybe_serialize
(
val
:
Callable
|
Any
,
keep_callable
=
False
)
->
Callable
|
Any
:
val
:
Union
[
Callable
,
Any
],
keep_callable
=
False
)
->
Union
[
Callable
,
Any
]:
"""Conditionally serializes a value if it is callable."""
"""Conditionally serializes a value if it is callable."""
return
(
return
(
...
...
lm_eval/filters/extraction.py
View file @
69d14fb3
import
re
import
re
import
sys
import
sys
import
unicodedata
import
unicodedata
from
collections.abc
import
Iterable
from
lm_eval.api.filter
import
Filter
from
lm_eval.api.filter
import
Filter
from
lm_eval.api.registry
import
register_filter
from
lm_eval.api.registry
import
register_filter
...
@@ -32,7 +33,9 @@ class RegexFilter(Filter):
...
@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self
.
group_select
=
group_select
self
.
group_select
=
group_select
self
.
fallback
=
fallback
self
.
fallback
=
fallback
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# so we process each of these (same input/target response sets)
...
@@ -59,59 +62,13 @@ class RegexFilter(Filter):
...
@@ -59,59 +62,13 @@ class RegexFilter(Filter):
return
filtered_resps
return
filtered_resps
@
register_filter
(
"regex_pos"
)
class
POSFilter
(
Filter
):
""" """
def
__init__
(
self
,
regex_pattern
:
str
=
r
"\['(.*?)'\]"
,
group_select
=
0
,
fallback
=
None
,
**
kwargs
,
)
->
None
:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super
().
__init__
(
**
kwargs
)
if
fallback
is
None
:
fallback
=
[
"invalid"
]
self
.
regex_pattern
=
regex_pattern
self
.
regex
=
re
.
compile
(
regex_pattern
)
self
.
group_select
=
group_select
self
.
fallback
=
fallback
def
apply
(
self
,
resps
,
docs
):
def
extract_tagged_tokens
(
text
):
# Extract tagged tokens list from text input using regex
tokens
=
re
.
findall
(
r
"\('([^']*)', '([^']*)'\)"
,
text
)
return
[(
token
,
pos
)
for
token
,
pos
in
tokens
]
def
extract_pos_tags
(
result
):
pos_tags
=
[]
if
isinstance
(
result
,
str
):
result
=
extract_tagged_tokens
(
result
)
pos_tags
.
extend
(
pos
for
_
,
pos
in
result
)
return
pos_tags
if
pos_tags
else
self
.
fallback
def
filter_set
(
inst
):
filtered
=
[]
for
resp
in
inst
:
match
=
extract_pos_tags
(
resp
)
filtered
.
append
(
match
)
return
filtered
filtered_resps
=
map
(
lambda
x
:
filter_set
(
x
),
resps
)
return
filtered_resps
@
register_filter
(
"remove_whitespace"
)
@
register_filter
(
"remove_whitespace"
)
class
WhitespaceFilter
(
Filter
):
class
WhitespaceFilter
(
Filter
):
"""Filters out leading whitespace from responses."""
"""Filters out leading whitespace from responses."""
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
def
filter_set
(
inst
):
def
filter_set
(
inst
):
filtered_resp
=
[]
filtered_resp
=
[]
for
resp
in
inst
:
for
resp
in
inst
:
...
@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
...
@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self
.
ignore_punctuation
=
ignore_punctuation
self
.
ignore_punctuation
=
ignore_punctuation
self
.
regexes_to_ignore
=
regexes_to_ignore
self
.
regexes_to_ignore
=
regexes_to_ignore
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# so we process each of these (same input/target response sets)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment