Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
ec767666
Commit
ec767666
authored
Jul 23, 2025
by
Baber
Browse files
overload Task methods if callable in yaml dict
parent
2009ec4b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
73 deletions
+80
-73
lm_eval/api/filter.py
lm_eval/api/filter.py
+3
-3
lm_eval/api/task.py
lm_eval/api/task.py
+53
-64
lm_eval/config/task.py
lm_eval/config/task.py
+4
-2
lm_eval/config/utils.py
lm_eval/config/utils.py
+20
-4
No files found.
lm_eval/api/filter.py
View file @
ec767666
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Protocol
,
runtime_checkable
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.instance
import
Instance
class
Filter
(
ABC
):
@
runtime_checkable
class
Filter
(
Protocol
):
"""
"""
Filter classes operate on a per-task level.
Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`)
They take all model outputs (`instance.resps` for all `task.instances`)
...
@@ -19,7 +20,6 @@ class Filter(ABC):
...
@@ -19,7 +20,6 @@ class Filter(ABC):
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
"""
@
abstractmethod
def
apply
(
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
)
->
Iterable
[
list
[
str
]]:
...
...
lm_eval/api/task.py
View file @
ec767666
...
@@ -7,6 +7,8 @@ import random
...
@@ -7,6 +7,8 @@ import random
import
re
import
re
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
copy
import
deepcopy
from
copy
import
deepcopy
from
functools
import
cached_property
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
import
datasets
import
datasets
...
@@ -143,14 +145,17 @@ class Task(abc.ABC):
...
@@ -143,14 +145,17 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class."""
"""Returns the TaskConfig associated with this class."""
return
self
.
_config
return
self
.
_config
@
property
def
has_training_docs
(
self
)
->
bool
:
def
has_training_docs
(
self
)
->
bool
:
"""Whether the task has a training set"""
"""Whether the task has a training set"""
raise
NotImplementedError
raise
NotImplementedError
@
property
def
has_validation_docs
(
self
)
->
bool
:
def
has_validation_docs
(
self
)
->
bool
:
"""Whether the task has a validation set"""
"""Whether the task has a validation set"""
raise
NotImplementedError
raise
NotImplementedError
@
property
def
has_test_docs
(
self
)
->
bool
:
def
has_test_docs
(
self
)
->
bool
:
"""Whether the task has a test set"""
"""Whether the task has a test set"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -181,9 +186,9 @@ class Task(abc.ABC):
...
@@ -181,9 +186,9 @@ class Task(abc.ABC):
:return: Iterable[obj]
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
A iterable of any object, that doc_to_text can handle
"""
"""
if
self
.
has_training_docs
()
:
if
self
.
has_training_docs
:
return
self
.
training_docs
()
return
self
.
training_docs
()
elif
self
.
has_validation_docs
()
:
elif
self
.
has_validation_docs
:
return
self
.
validation_docs
()
return
self
.
validation_docs
()
else
:
else
:
if
self
.
config
.
num_fewshot
and
self
.
config
.
num_fewshot
>
0
:
if
self
.
config
.
num_fewshot
and
self
.
config
.
num_fewshot
>
0
:
...
@@ -211,7 +216,7 @@ class Task(abc.ABC):
...
@@ -211,7 +216,7 @@ class Task(abc.ABC):
"""
"""
return
self
.
_instances
return
self
.
_instances
def
fewshot_examples
(
self
,
k
,
rnd
)
->
Iterable
[
dict
]:
def
fewshot_examples
(
self
,
k
:
int
,
rnd
)
->
Iterable
[
dict
]:
if
self
.
_training_docs
is
None
:
if
self
.
_training_docs
is
None
:
self
.
_training_docs
=
list
(
self
.
training_docs
())
self
.
_training_docs
=
list
(
self
.
training_docs
())
...
@@ -449,13 +454,13 @@ class Task(abc.ABC):
...
@@ -449,13 +454,13 @@ class Task(abc.ABC):
labeled_examples
=
""
labeled_examples
=
""
else
:
else
:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if
self
.
has_training_docs
()
:
if
self
.
has_training_docs
:
fewshotex
=
self
.
fewshot_examples
(
k
=
num_fewshot
,
rnd
=
rnd
)
fewshotex
=
self
.
fewshot_examples
(
k
=
num_fewshot
,
rnd
=
rnd
)
else
:
else
:
if
self
.
_fewshot_docs
is
None
:
if
self
.
_fewshot_docs
is
None
:
self
.
_fewshot_docs
=
list
(
self
.
_fewshot_docs
=
list
(
self
.
validation_docs
()
self
.
validation_docs
()
if
self
.
has_validation_docs
()
if
self
.
has_validation_docs
else
self
.
test_docs
()
else
self
.
test_docs
()
)
)
...
@@ -528,9 +533,9 @@ class Task(abc.ABC):
...
@@ -528,9 +533,9 @@ class Task(abc.ABC):
@
property
@
property
def
eval_docs
(
self
)
->
datasets
.
Dataset
|
Iterable
[
dict
]:
def
eval_docs
(
self
)
->
datasets
.
Dataset
|
Iterable
[
dict
]:
if
self
.
has_test_docs
()
:
if
self
.
has_test_docs
:
return
self
.
test_docs
()
return
self
.
test_docs
()
elif
self
.
has_validation_docs
()
:
elif
self
.
has_validation_docs
:
return
self
.
validation_docs
()
return
self
.
validation_docs
()
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
@@ -587,7 +592,7 @@ class ConfigurableTask(Task):
...
@@ -587,7 +592,7 @@ class ConfigurableTask(Task):
# Use new configurations if there was no preconfiguration
# Use new configurations if there was no preconfiguration
if
self
.
config
is
None
:
if
self
.
config
is
None
:
self
.
_config
=
TaskConfig
(
**
config
)
self
.
_config
=
TaskConfig
.
from_yaml
(
config
)
# Overwrite configs
# Overwrite configs
else
:
else
:
if
config
is
not
None
:
if
config
is
not
None
:
...
@@ -730,17 +735,20 @@ class ConfigurableTask(Task):
...
@@ -730,17 +735,20 @@ class ConfigurableTask(Task):
**
self
.
config
.
dataset_kwargs
,
**
self
.
config
.
dataset_kwargs
,
)
)
@
cached_property
def
has_training_docs
(
self
)
->
bool
:
def
has_training_docs
(
self
)
->
bool
:
return
self
.
config
.
training_split
is
not
None
return
self
.
config
.
training_split
is
not
None
@
cached_property
def
has_validation_docs
(
self
)
->
bool
:
def
has_validation_docs
(
self
)
->
bool
:
return
self
.
config
.
validation_split
is
not
None
return
self
.
config
.
validation_split
is
not
None
@
cached_property
def
has_test_docs
(
self
)
->
bool
:
def
has_test_docs
(
self
)
->
bool
:
return
self
.
config
.
test_split
is
not
None
return
self
.
config
.
test_split
is
not
None
def
training_docs
(
self
)
->
DataSet
|
None
:
def
training_docs
(
self
)
->
DataSet
|
None
:
if
self
.
has_training_docs
()
:
if
self
.
has_training_docs
:
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
training_split
]
self
.
dataset
[
self
.
config
.
training_split
]
...
@@ -748,7 +756,7 @@ class ConfigurableTask(Task):
...
@@ -748,7 +756,7 @@ class ConfigurableTask(Task):
return
self
.
dataset
[
self
.
config
.
training_split
]
return
self
.
dataset
[
self
.
config
.
training_split
]
def
validation_docs
(
self
)
->
DataSet
|
None
:
def
validation_docs
(
self
)
->
DataSet
|
None
:
if
self
.
has_validation_docs
()
:
if
self
.
has_validation_docs
:
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
validation_split
]
self
.
dataset
[
self
.
config
.
validation_split
]
...
@@ -756,7 +764,7 @@ class ConfigurableTask(Task):
...
@@ -756,7 +764,7 @@ class ConfigurableTask(Task):
return
self
.
dataset
[
self
.
config
.
validation_split
]
return
self
.
dataset
[
self
.
config
.
validation_split
]
def
test_docs
(
self
)
->
DataSet
|
None
:
def
test_docs
(
self
)
->
DataSet
|
None
:
if
self
.
has_test_docs
()
:
if
self
.
has_test_docs
:
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
return
self
.
dataset
[
self
.
config
.
test_split
]
return
self
.
dataset
[
self
.
config
.
test_split
]
...
@@ -1011,23 +1019,16 @@ class ConfigurableTask(Task):
...
@@ -1011,23 +1019,16 @@ class ConfigurableTask(Task):
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_text = self.prompt
# doc_to_text = self.prompt
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
if
doc_to_text
in
doc
:
if
isinstance
(
doc_to_text
,
int
):
return
doc
[
doc_to_text
]
return
doc_to_text
elif
isinstance
(
doc_to_text
,
str
):
elif
isinstance
(
doc_to_text
,
str
):
if
doc_to_text
in
self
.
features
:
text_string
=
utils
.
apply_template
(
doc_to_text
,
doc
)
# if self.config.doc_to_choice is not None:
if
text_string
.
isdigit
()
and
self
.
config
.
doc_to_choice
is
not
None
:
# return self.doc_to_choice(doc)[doc[doc_to_text]]
return
ast
.
literal_eval
(
text_string
)
# else:
return
doc
[
doc_to_text
]
else
:
else
:
text_string
=
utils
.
apply_template
(
doc_to_text
,
doc
)
return
text_string
if
text_string
.
isdigit
()
and
self
.
config
.
doc_to_choice
is
not
None
:
elif
isinstance
(
doc_to_text
,
int
):
return
ast
.
literal_eval
(
text_string
)
return
doc_to_text
else
:
return
text_string
elif
callable
(
doc_to_text
):
return
doc_to_text
(
doc
)
# Used when applying a Promptsource template
# Used when applying a Promptsource template
# elif hasattr(doc_to_text, "apply"):
# elif hasattr(doc_to_text, "apply"):
# applied_prompt = doc_to_text.apply(doc)
# applied_prompt = doc_to_text.apply(doc)
...
@@ -1062,38 +1063,31 @@ class ConfigurableTask(Task):
...
@@ -1062,38 +1063,31 @@ class ConfigurableTask(Task):
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_target = self.prompt
# doc_to_target = self.prompt
if
doc_to_target
is
not
None
:
doc_to_target
=
doc_to_target
or
self
.
config
.
doc_to_target
doc_to_target
=
doc_to_target
if
doc_to_target
in
doc
:
else
:
return
doc
[
doc_to_target
]
doc_to_target
=
self
.
config
.
doc_to_target
if
isinstance
(
doc_to_target
,
int
):
return
doc_to_target
elif
isinstance
(
doc_to_target
,
str
):
elif
isinstance
(
doc_to_target
,
str
):
if
doc_to_target
in
self
.
features
:
target_string
=
utils
.
apply_template
(
doc_to_target
,
doc
)
# if self.config.doc_to_choice is not None:
if
target_string
.
isdigit
()
and
self
.
config
.
doc_to_choice
is
not
None
:
# return self.doc_to_choice(doc)[doc[doc_to_target]]
return
ast
.
literal_eval
(
target_string
)
# else:
# elif (
return
doc
[
doc_to_target
]
# len(target_string) >= 2
# and (target_string[0] == "[")
# and (target_string[-1] == "]")
# ):
# try:
# return ast.literal_eval(target_string)
# except (SyntaxError, ValueError):
# return target_string
else
:
else
:
target_string
=
utils
.
apply_template
(
doc_to_target
,
doc
)
return
target_string
if
target_string
.
isdigit
()
and
self
.
config
.
doc_to_choice
is
not
None
:
return
ast
.
literal_eval
(
target_string
)
elif
isinstance
(
doc_to_target
,
(
int
,
list
)):
elif
(
len
(
target_string
)
>=
2
and
(
target_string
[
0
]
==
"["
)
and
(
target_string
[
-
1
]
==
"]"
)
):
try
:
return
ast
.
literal_eval
(
target_string
)
except
(
SyntaxError
,
ValueError
):
return
target_string
else
:
return
target_string
elif
isinstance
(
doc_to_target
,
list
):
return
doc_to_target
return
doc_to_target
elif
callable
(
doc_to_target
):
# elif isinstance(doc_to_target, list):
return
doc_to_target
(
doc
)
# return doc_to_target
# elif callable(doc_to_target):
# return doc_to_target(doc)
# # Used when applying a Promptsource template
# # Used when applying a Promptsource template
# elif hasattr(doc_to_target, "apply"):
# elif hasattr(doc_to_target, "apply"):
# applied_prompt = doc_to_target.apply(doc)
# applied_prompt = doc_to_target.apply(doc)
...
@@ -1138,16 +1132,14 @@ class ConfigurableTask(Task):
...
@@ -1138,16 +1132,14 @@ class ConfigurableTask(Task):
doc_to_choice
=
self
.
config
.
doc_to_choice
doc_to_choice
=
self
.
config
.
doc_to_choice
if
isinstance
(
doc_to_choice
,
str
):
if
isinstance
(
doc_to_choice
,
str
):
if
doc_to_choice
in
self
.
features
:
if
doc_to_choice
in
doc
:
return
doc
[
doc_to_choice
]
return
doc
[
doc_to_choice
]
else
:
else
:
return
ast
.
literal_eval
(
utils
.
apply_template
(
doc_to_choice
,
doc
))
return
ast
.
literal_eval
(
utils
.
apply_template
(
doc_to_choice
,
doc
))
elif
isinstance
(
doc_to_choice
,
list
):
elif
isinstance
(
doc_to_choice
,
list
):
return
doc_to_choice
return
doc_to_choice
elif
isinstance
(
doc_to_choice
,
dict
):
# elif isinstance(doc_to_choice, dict):
return
list
(
doc_to_choice
.
values
())
# return list(doc_to_choice.values())
elif
callable
(
doc_to_choice
):
return
doc_to_choice
(
doc
)
# elif hasattr(doc_to_choice, "get_answer_choices_list"):
# elif hasattr(doc_to_choice, "get_answer_choices_list"):
# return doc_to_choice.get_answer_choices_list(doc)
# return doc_to_choice.get_answer_choices_list(doc)
else
:
else
:
...
@@ -1225,7 +1217,7 @@ class ConfigurableTask(Task):
...
@@ -1225,7 +1217,7 @@ class ConfigurableTask(Task):
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
str
|
None
:
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
str
|
None
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
gen_prefix
in
self
.
features
:
if
gen_prefix
in
doc
:
return
doc
[
gen_prefix
]
return
doc
[
gen_prefix
]
else
:
else
:
return
utils
.
apply_template
(
gen_prefix
,
doc
)
return
utils
.
apply_template
(
gen_prefix
,
doc
)
...
@@ -1333,9 +1325,6 @@ class ConfigurableTask(Task):
...
@@ -1333,9 +1325,6 @@ class ConfigurableTask(Task):
)
)
def
process_results
(
self
,
doc
:
dict
,
results
:
list
)
->
dict
[
str
,
Any
]:
def
process_results
(
self
,
doc
:
dict
,
results
:
list
)
->
dict
[
str
,
Any
]:
if
callable
(
self
.
config
.
process_results
):
return
self
.
config
.
process_results
(
doc
,
results
)
result_dict
=
{}
result_dict
=
{}
use_metric
=
list
(
m
.
metric_name
for
m
in
self
.
config
.
_metric_list
)
use_metric
=
list
(
m
.
metric_name
for
m
in
self
.
config
.
_metric_list
)
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
...
...
lm_eval/config/task.py
View file @
ec767666
...
@@ -10,7 +10,7 @@ import datasets
...
@@ -10,7 +10,7 @@ import datasets
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
from
lm_eval.api.instance
import
OutputType
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.utils
import
maybe_serialize
from
lm_eval.config.utils
import
doc_to_closure
,
maybe_serialize
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -179,6 +179,7 @@ class TaskConfig:
...
@@ -179,6 +179,7 @@ class TaskConfig:
_filter_list
:
list
[
FilterConfig
]
=
field
(
default_factory
=
list
)
_filter_list
:
list
[
FilterConfig
]
=
field
(
default_factory
=
list
)
# ds_cfg: DatasetConfig = field(init=False)
# ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg
:
FewshotConfig
=
field
(
init
=
False
)
fewshot_cfg
:
FewshotConfig
=
field
(
init
=
False
)
_fn
:
dict
[
str
,
Callable
]
=
field
(
default_factory
=
dict
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
### ---setup generation kwargs--- ###
### ---setup generation kwargs--- ###
...
@@ -363,7 +364,8 @@ class TaskConfig:
...
@@ -363,7 +364,8 @@ class TaskConfig:
@
classmethod
@
classmethod
def
from_yaml
(
cls
,
data
:
dict
[
str
,
Any
])
->
TaskConfig
:
def
from_yaml
(
cls
,
data
:
dict
[
str
,
Any
])
->
TaskConfig
:
"""Create a TaskConfig instance from a YAML-like dictionary."""
"""Create a TaskConfig instance from a YAML-like dictionary."""
return
cls
(
**
data
)
fn
=
{
k
:
doc_to_closure
(
v
)
for
k
,
v
in
data
.
items
()
if
callable
(
v
)}
return
cls
(
**
data
,
_fn
=
fn
)
@
classmethod
@
classmethod
def
from_template
(
cls
,
template
:
TemplateConfig
,
**
kwargs
)
->
TaskConfig
:
def
from_template
(
cls
,
template
:
TemplateConfig
,
**
kwargs
)
->
TaskConfig
:
...
...
lm_eval/config/utils.py
View file @
ec767666
from
__future__
import
annotations
from
__future__
import
annotations
from
functools
import
wraps
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
Any
,
Callable
from
typing
import
Any
,
Callable
,
TypeVar
T
=
TypeVar
(
"T"
)
def
serialize_callable
(
def
serialize_callable
(
value
:
Callable
[...,
Any
]
|
str
,
keep_callable
=
False
value
:
Callable
[...,
T
]
|
str
,
keep_callable
=
False
)
->
Callable
[...,
Any
]
|
str
:
)
->
Callable
[...,
T
]
|
str
:
"""Serializes a given function or string.
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
If 'keep_callable' is True, the original callable is returned.
...
@@ -22,7 +26,9 @@ def serialize_callable(
...
@@ -22,7 +26,9 @@ def serialize_callable(
return
str
(
value
)
return
str
(
value
)
def
maybe_serialize
(
val
:
Callable
|
Any
,
keep_callable
=
False
)
->
Callable
|
Any
:
def
maybe_serialize
(
val
:
Callable
[...,
T
]
|
Any
,
keep_callable
=
False
)
->
Callable
[...,
T
]
|
Any
:
"""Conditionally serializes a value if it is callable."""
"""Conditionally serializes a value if it is callable."""
return
(
return
(
...
@@ -41,3 +47,13 @@ def create_mc_choices(choices: list[str], choice_delimiter: str | None = "\n") -
...
@@ -41,3 +47,13 @@ def create_mc_choices(choices: list[str], choice_delimiter: str | None = "\n") -
formatted_choices
=
[
f
"
{
chr
(
65
+
i
)
}
.
{
choice
}
"
for
i
,
choice
in
enumerate
(
choices
)]
formatted_choices
=
[
f
"
{
chr
(
65
+
i
)
}
.
{
choice
}
"
for
i
,
choice
in
enumerate
(
choices
)]
return
choice_delimiter
.
join
(
formatted_choices
)
return
choice_delimiter
.
join
(
formatted_choices
)
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
    """Wrap ``fn`` so it can be invoked with a leading ``self`` argument.

    The returned callable accepts ``self`` as its first positional argument,
    discards it, and forwards every remaining positional and keyword
    argument to ``fn`` unchanged. ``functools.wraps`` copies ``fn``'s
    metadata (name, docstring, ...) onto the wrapper.
    """

    def _drop_self(self: Any, *args, **kwargs):
        # ``self`` only absorbs the instance; ``fn`` never receives it.
        result = fn(*args, **kwargs)
        return result

    return wraps(fn)(_drop_self)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment