Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
04e74420
Commit
04e74420
authored
Jul 03, 2025
by
Baber
Browse files
cleanup
parent
b0173d57
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
93 additions
and
94 deletions
+93
-94
lm_eval/api/model.py
lm_eval/api/model.py
+11
-5
lm_eval/api/registry.py
lm_eval/api/registry.py
+6
-2
lm_eval/api/task.py
lm_eval/api/task.py
+69
-76
lm_eval/filters/decontamination.py
lm_eval/filters/decontamination.py
+2
-1
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+4
-0
lm_eval/filters/selection.py
lm_eval/filters/selection.py
+0
-1
lm_eval/filters/transformation.py
lm_eval/filters/transformation.py
+1
-9
No files found.
lm_eval/api/model.py
View file @
04e74420
...
@@ -176,14 +176,14 @@ class LM(abc.ABC):
...
@@ -176,14 +176,14 @@ class LM(abc.ABC):
return
cls
(
**
arg_dict
,
**
additional_config
)
return
cls
(
**
arg_dict
,
**
additional_config
)
@
property
@
property
def
rank
(
self
):
def
rank
(
self
)
->
int
:
# used in the case of parallelism. Hardcoded to
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
# not support multi-device parallelism nor expect it.
return
self
.
_rank
return
self
.
_rank
@
property
@
property
def
world_size
(
self
):
def
world_size
(
self
)
->
int
:
# used in the case of parallelism. Hardcoded to
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
# not support multi-device parallelism nor expect it.
...
@@ -233,7 +233,7 @@ class CacheHook:
...
@@ -233,7 +233,7 @@ class CacheHook:
class
CachingLM
:
class
CachingLM
:
def
__init__
(
self
,
lm
:
LM
,
cache_db
:
str
)
->
None
:
def
__init__
(
self
,
lm
:
"
LM
"
,
cache_db
:
str
)
->
None
:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
:param lm: LM
...
@@ -327,11 +327,11 @@ class TemplateLM(LM):
...
@@ -327,11 +327,11 @@ class TemplateLM(LM):
@
property
@
property
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
eot_token_id
(
self
):
def
eot_token_id
(
self
)
->
int
:
pass
pass
@
property
@
property
def
prefix_token_id
(
self
):
def
prefix_token_id
(
self
)
->
int
:
# it is used as prefix for loglikelihood
# it is used as prefix for loglikelihood
return
self
.
eot_token_id
return
self
.
eot_token_id
...
@@ -351,6 +351,11 @@ class TemplateLM(LM):
...
@@ -351,6 +351,11 @@ class TemplateLM(LM):
def
_encode_pair
(
def
_encode_pair
(
self
,
context
:
str
,
continuation
:
str
self
,
context
:
str
,
continuation
:
str
)
->
tuple
[
list
[
int
],
list
[
int
]]:
)
->
tuple
[
list
[
int
],
list
[
int
]]:
"""Encodes a pair of context and continuation strings into token IDs.
Ensures that encode(context + continuation) == encode(context) + encode(continuation)
"""
import
transformers
import
transformers
n_spaces
=
len
(
context
)
-
len
(
context
.
rstrip
())
n_spaces
=
len
(
context
)
-
len
(
context
.
rstrip
())
...
@@ -402,6 +407,7 @@ class TemplateLM(LM):
...
@@ -402,6 +407,7 @@ class TemplateLM(LM):
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]:
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]:
"""
"""
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model.
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
...
...
lm_eval/api/registry.py
View file @
04e74420
...
@@ -8,6 +8,10 @@ if TYPE_CHECKING:
...
@@ -8,6 +8,10 @@ if TYPE_CHECKING:
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
MODEL_REGISTRY
=
{}
MODEL_REGISTRY
=
{}
DEFAULTS
=
{
"model"
:
{
"max_length"
:
2048
},
"tasks"
:
{
"generate_until"
:
{
"max_length"
:
2048
}},
}
def
register_model
(
*
names
):
def
register_model
(
*
names
):
...
@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
...
@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
def
is_higher_better
(
metric_name
)
->
Optional
[
bool
]:
def
is_higher_better
(
metric_name
:
str
)
->
Optional
[
bool
]:
try
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
except
KeyError
:
...
@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
...
@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
)
)
def
register_filter
(
name
):
def
register_filter
(
name
:
str
):
def
decorate
(
cls
):
def
decorate
(
cls
):
if
name
in
FILTER_REGISTRY
:
if
name
in
FILTER_REGISTRY
:
eval_logger
.
info
(
eval_logger
.
info
(
...
...
lm_eval/api/task.py
View file @
04e74420
...
@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
...
@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
from
functools
import
cached_property
from
functools
import
cached_property
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
(
from
typing
import
(
TYPE_CHECKING
,
Any
,
Any
,
Dict
,
Dict
,
Iterable
,
Iterable
,
...
@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
...
@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
"generate_until"
,
"generate_until"
,
]
]
if
TYPE_CHECKING
:
from
lm_eval.api.filter
import
FilterEnsemble
eval_logger
=
logging
.
getLogger
(
__name__
)
eval_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -81,7 +86,7 @@ class MetricConfig:
...
@@ -81,7 +86,7 @@ class MetricConfig:
return
is_higher_better
(
self
.
name
)
return
is_higher_better
(
self
.
name
)
return
self
.
higher_is_better
return
self
.
higher_is_better
def
c
alcula
te_metric
(
self
,
*
args
,
**
kwargs
)
->
Any
:
def
c
ompu
te_metric
(
self
,
*
args
,
**
kwargs
)
->
Any
:
"""Calculates the metric using the provided function and arguments."""
"""Calculates the metric using the provided function and arguments."""
if
self
.
fn
is
None
:
if
self
.
fn
is
None
:
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
...
@@ -99,7 +104,7 @@ class RepeatConfig:
...
@@ -99,7 +104,7 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
"""Encapsulates information about a single repeat."""
repeats
:
int
=
1
repeats
:
int
=
1
metric_fn
:
Optional
[
Callable
]
=
None
metric_fn
:
Optional
[
str
,
Callable
]
=
"pass@N"
kwargs
:
Optional
[
dict
]
=
None
kwargs
:
Optional
[
dict
]
=
None
...
@@ -246,15 +251,15 @@ class TaskConfig(dict):
...
@@ -246,15 +251,15 @@ class TaskConfig(dict):
output_type
:
OutputType
=
"generate_until"
output_type
:
OutputType
=
"generate_until"
generation_kwargs
:
Optional
[
dict
]
=
None
generation_kwargs
:
Optional
[
dict
]
=
None
repeats
:
int
=
1
repeats
:
int
=
1
filter_list
:
Optional
[
Union
[
str
,
lis
t
]]
=
None
filter_list
:
Optional
[
list
[
dic
t
]]
=
None
should_decontaminate
:
bool
=
False
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
doc_to_decontamination_query
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
metadata
:
Optional
[
dict
]
=
(
metadata
:
Optional
[
dict
]
=
(
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
)
)
_metric_list
=
None
_metric_list
:
list
[
MetricConfig
]
=
None
_filter_list
=
None
_filter_list
:
list
[
FilterConfig
]
=
None
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
if
self
.
generation_kwargs
is
not
None
:
if
self
.
generation_kwargs
is
not
None
:
...
@@ -289,16 +294,13 @@ class TaskConfig(dict):
...
@@ -289,16 +294,13 @@ class TaskConfig(dict):
f
"
{
self
.
task
}
: No `generation_kwargs` specified in task config, defaulting to
{
self
.
generation_kwargs
}
"
f
"
{
self
.
task
}
: No `generation_kwargs` specified in task config, defaulting to
{
self
.
generation_kwargs
}
"
)
)
if
self
.
metric_list
is
not
None
:
if
self
.
metric_list
and
not
all
(
"metric"
in
cfg
for
cfg
in
self
.
metric_list
):
for
metric_config
in
self
.
metric_list
:
raise
ValueError
(
"each entry in metric_list must include a 'metric' key"
)
if
"metric"
not
in
metric_config
:
raise
ValueError
(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
def
get_metrics
(
self
)
->
list
[
"MetricConfig"
]:
def
get_metrics
(
self
)
->
list
[
"MetricConfig"
]:
metrics
=
[]
metrics
=
[]
if
self
.
metric_list
is
None
:
if
self
.
metric_list
is
None
:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
output_type
]
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
output_type
]
eval_logger
.
info
(
eval_logger
.
info
(
f
"No metrics defined in config, using default metrics for
{
self
.
output_type
}
=
{
_metric_list
}
"
f
"No metrics defined in config, using default metrics for
{
self
.
output_type
}
=
{
_metric_list
}
"
...
@@ -313,11 +315,8 @@ class TaskConfig(dict):
...
@@ -313,11 +315,8 @@ class TaskConfig(dict):
for
metric_name
in
_metric_list
for
metric_name
in
_metric_list
)
)
else
:
else
:
# ---------- 2. How will the samples be evaluated ----------
for
metric_config
in
self
.
metric_list
:
for
metric_config
in
self
.
metric_list
:
if
"metric"
not
in
metric_config
:
raise
ValueError
(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name
=
metric_config
[
"metric"
]
metric_name
=
metric_config
[
"metric"
]
_metric_fn_kwargs
=
{
_metric_fn_kwargs
=
{
key
:
metric_config
[
key
]
key
:
metric_config
[
key
]
...
@@ -379,34 +378,30 @@ class TaskConfig(dict):
...
@@ -379,34 +378,30 @@ class TaskConfig(dict):
)
)
return
metrics
return
metrics
def
get_filters
(
self
):
def
get_filters
(
self
)
->
list
[
"FilterEnsemble"
]:
if
self
.
filter_list
is
not
None
:
if
not
self
.
filter_list
:
_filter_list
=
[]
if
isinstance
(
self
.
filter_list
,
dict
):
for
filter_config
in
self
.
filter_list
:
_filter_list
.
append
(
build_filter_ensemble
(
filter_name
=
filter_config
[
"name"
],
components
=
[
[
{
key
:
function
[
key
]
for
key
in
function
if
key
!=
"function"
}
]
for
function
in
filter_config
[
"filter"
]
],
)
)
else
:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger
.
debug
(
eval_logger
.
debug
(
"No custom filters defined
. Using default
'take_first'
filter
for handling repeats."
"No custom filters defined
; falling back to
'take_first' for handling repeats."
)
)
_filter_list
=
[
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])]
return
[
build_filter_ensemble
(
"none"
,
[[
"take_first"
,
None
]])]
else
:
return
_filter_list
def
_strip_fn
(
d
:
dict
)
->
dict
:
return
{
k
:
v
for
k
,
v
in
d
.
items
()
if
k
!=
"function"
}
configs
=
(
self
.
filter_list
.
values
()
if
isinstance
(
self
.
filter_list
,
dict
)
else
self
.
filter_list
)
return
[
build_filter_ensemble
(
filter_name
=
cfg
[
"name"
],
components
=
[[
_strip_fn
(
f
)
for
f
in
cfg
[
"filter"
]]],
)
for
cfg
in
configs
]
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
return
getattr
(
self
,
item
)
...
@@ -415,31 +410,27 @@ class TaskConfig(dict):
...
@@ -415,31 +410,27 @@ class TaskConfig(dict):
return
setattr
(
self
,
item
,
value
)
return
setattr
(
self
,
item
,
value
)
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
"""dumps the current config as a dictionary object, as a printable format.
"""Return a printable dict with Nones stripped and callables serialised.
null fields will not be printed.
Used for dumping results alongside full task configuration
:return: dict
:return: dict
A printable dictionary version of the TaskConfig object.
A printable dictionary version of the TaskConfig object.
# TODO: should any default value in the TaskConfig not be printed?
"""
"""
cfg_dict
=
asdict
(
self
)
# remove values that are `None`
def
_maybe_serialize
(
val
):
for
k
,
v
in
list
(
cfg_dict
.
items
()):
return
(
if
v
is
None
:
self
.
serialize_function
(
val
,
keep_callable
=
keep_callable
)
cfg_dict
.
pop
(
k
)
if
callable
(
val
)
elif
k
==
"metric_list"
:
else
val
for
metric_dict
in
v
:
)
for
metric_key
,
metric_value
in
metric_dict
.
items
():
if
callable
(
metric_value
):
cfg
=
asdict
(
self
)
metric_dict
[
metric_key
]
=
self
.
serialize_function
(
return
{
metric_value
,
keep_callable
=
keep_callable
k
:
[{
mk
:
_maybe_serialize
(
mv
)
for
mk
,
mv
in
md
.
items
()}
for
md
in
v
]
)
if
k
==
"metric_list"
cfg_dict
[
k
]
=
v
else
_maybe_serialize
(
v
)
elif
callable
(
v
):
for
k
,
v
in
cfg
.
items
()
cfg_dict
[
k
]
=
self
.
serialize_function
(
v
,
keep_callable
=
keep_callable
)
if
v
is
not
None
return
cfg_dict
}
def
serialize_function
(
def
serialize_function
(
self
,
value
:
Union
[
Callable
,
str
],
keep_callable
=
False
self
,
value
:
Union
[
Callable
,
str
],
keep_callable
=
False
...
@@ -627,7 +618,7 @@ class Task(abc.ABC):
...
@@ -627,7 +618,7 @@ class Task(abc.ABC):
return
doc
return
doc
@
property
@
property
def
instances
(
self
)
->
L
ist
[
Instance
]:
def
instances
(
self
)
->
l
ist
[
Instance
]:
"""After calling `task.build_all_requests()`, tasks
"""After calling `task.build_all_requests()`, tasks
maintain a list of the dataset instances which will be evaluated.
maintain a list of the dataset instances which will be evaluated.
"""
"""
...
@@ -639,27 +630,27 @@ class Task(abc.ABC):
...
@@ -639,27 +630,27 @@ class Task(abc.ABC):
return
rnd
.
sample
(
self
.
_training_docs
,
k
)
return
rnd
.
sample
(
self
.
_training_docs
,
k
)
def
doc_to_decontamination_query
(
self
,
doc
):
def
doc_to_decontamination_query
(
self
,
doc
:
dict
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Override doc_to_decontamination_query with document specific decontamination query."
"Override doc_to_decontamination_query with document specific decontamination query."
)
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
doc_to_text
(
self
,
doc
)
->
str
:
def
doc_to_text
(
self
,
doc
:
dict
)
->
str
:
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
doc_to_target
(
self
,
doc
)
->
Union
[
str
,
int
]:
def
doc_to_target
(
self
,
doc
:
dict
)
->
Union
[
str
,
int
]:
pass
pass
# not an abstractmethod because not every language-only task has to implement this
# not an abstractmethod because not every language-only task has to implement this
def
doc_to_image
(
self
,
doc
):
def
doc_to_image
(
self
,
doc
:
dict
):
raise
NotImplementedError
raise
NotImplementedError
def
doc_to_audio
(
self
,
doc
):
def
doc_to_audio
(
self
,
doc
:
dict
):
raise
NotImplementedError
raise
NotImplementedError
def
doc_to_prefix
(
self
,
doc
)
->
str
:
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
str
:
return
""
return
""
def
build_all_requests
(
def
build_all_requests
(
...
@@ -776,7 +767,7 @@ class Task(abc.ABC):
...
@@ -776,7 +767,7 @@ class Task(abc.ABC):
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
Union
[
list
[
dict
],
str
]
,
**
kwargs
):
"""Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
Requests which will be sent to the LM.
...
@@ -797,7 +788,7 @@ class Task(abc.ABC):
...
@@ -797,7 +788,7 @@ class Task(abc.ABC):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
:
dict
,
results
:
list
):
"""Take a single document and the LM results and evaluates, returning a
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
dict where keys are the names of submetrics and values are the values of
the metric for that one document
the metric for that one document
...
@@ -1450,7 +1441,7 @@ class ConfigurableTask(Task):
...
@@ -1450,7 +1441,7 @@ class ConfigurableTask(Task):
"""
"""
return
doc
return
doc
def
doc_to_text
(
self
,
doc
,
doc_to_text
=
None
):
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
Optional
[
int
,
str
,
Callable
]
=
None
):
if
self
.
prompt
is
not
None
:
if
self
.
prompt
is
not
None
:
doc_to_text
=
self
.
prompt
doc_to_text
=
self
.
prompt
elif
doc_to_text
is
not
None
:
elif
doc_to_text
is
not
None
:
...
@@ -1486,7 +1477,7 @@ class ConfigurableTask(Task):
...
@@ -1486,7 +1477,7 @@ class ConfigurableTask(Task):
print
(
type
(
doc_to_text
))
print
(
type
(
doc_to_text
))
raise
TypeError
raise
TypeError
def
doc_to_target
(
self
,
doc
:
Mapping
,
doc_to_target
=
None
)
->
Union
[
int
,
str
,
list
]:
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
Union
[
int
,
str
,
list
]:
if
self
.
prompt
is
not
None
:
if
self
.
prompt
is
not
None
:
doc_to_target
=
self
.
prompt
doc_to_target
=
self
.
prompt
elif
doc_to_target
is
not
None
:
elif
doc_to_target
is
not
None
:
...
@@ -1532,7 +1523,9 @@ class ConfigurableTask(Task):
...
@@ -1532,7 +1523,9 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
def
doc_to_choice
(
self
,
doc
:
Any
,
doc_to_choice
=
None
)
->
List
[
str
]:
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
Union
[
str
,
list
,
dict
]
=
None
)
->
List
[
str
]:
if
self
.
prompt
is
not
None
:
if
self
.
prompt
is
not
None
:
doc_to_choice
=
self
.
prompt
doc_to_choice
=
self
.
prompt
elif
doc_to_choice
is
not
None
:
elif
doc_to_choice
is
not
None
:
...
@@ -1558,7 +1551,7 @@ class ConfigurableTask(Task):
...
@@ -1558,7 +1551,7 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
def
doc_to_image
(
self
,
doc
:
Any
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
,
None
]:
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
,
None
]:
if
doc_to_image
is
not
None
:
if
doc_to_image
is
not
None
:
doc_to_image
=
doc_to_image
doc_to_image
=
doc_to_image
elif
self
.
config
.
doc_to_image
is
not
None
:
elif
self
.
config
.
doc_to_image
is
not
None
:
...
@@ -1604,7 +1597,7 @@ class ConfigurableTask(Task):
...
@@ -1604,7 +1597,7 @@ class ConfigurableTask(Task):
else
:
else
:
return
None
return
None
def
doc_to_prefix
(
self
,
doc
)
->
Optional
[
str
]:
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
Optional
[
str
]:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
gen_prefix
in
self
.
features
:
if
gen_prefix
in
self
.
features
:
return
doc
[
gen_prefix
]
return
doc
[
gen_prefix
]
...
@@ -1713,7 +1706,7 @@ class ConfigurableTask(Task):
...
@@ -1713,7 +1706,7 @@ class ConfigurableTask(Task):
**
kwargs
,
**
kwargs
,
)
)
def
process_results
(
self
,
doc
,
results
)
:
def
process_results
(
self
,
doc
:
dict
,
results
:
list
)
->
dict
:
if
callable
(
self
.
config
.
process_results
):
if
callable
(
self
.
config
.
process_results
):
return
self
.
config
.
process_results
(
doc
,
results
)
return
self
.
config
.
process_results
(
doc
,
results
)
...
...
lm_eval/filters/decontamination.py
View file @
04e74420
...
@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
...
@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name
=
"track_decontamination"
name
=
"track_decontamination"
def
__init__
(
self
,
path
)
->
None
:
def
__init__
(
self
,
path
,
**
kwargs
)
->
None
:
"""
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id)
should further cache result on a given (task_name, doc_id)
"""
"""
super
().
__init__
(
**
kwargs
)
self
.
_decontam_results
=
None
self
.
_decontam_results
=
None
def
apply
(
self
,
resps
,
docs
)
->
None
:
def
apply
(
self
,
resps
,
docs
)
->
None
:
...
...
lm_eval/filters/extraction.py
View file @
04e74420
...
@@ -20,11 +20,13 @@ class RegexFilter(Filter):
...
@@ -20,11 +20,13 @@ class RegexFilter(Filter):
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
group_select
:
int
=
0
,
group_select
:
int
=
0
,
fallback
:
str
=
"[invalid]"
,
fallback
:
str
=
"[invalid]"
,
**
kwargs
,
)
->
None
:
)
->
None
:
"""
"""
pass a string `regex` to run `re.compile(r"regex")` on.
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
`fallback` defines the output returned if no matches for the regex are located.
"""
"""
super
().
__init__
(
**
kwargs
)
self
.
regex_pattern
=
regex_pattern
self
.
regex_pattern
=
regex_pattern
self
.
regex
=
re
.
compile
(
regex_pattern
)
self
.
regex
=
re
.
compile
(
regex_pattern
)
self
.
group_select
=
group_select
self
.
group_select
=
group_select
...
@@ -66,11 +68,13 @@ class POSFilter(Filter):
...
@@ -66,11 +68,13 @@ class POSFilter(Filter):
regex_pattern
:
str
=
r
"\['(.*?)'\]"
,
regex_pattern
:
str
=
r
"\['(.*?)'\]"
,
group_select
=
0
,
group_select
=
0
,
fallback
=
None
,
fallback
=
None
,
**
kwargs
,
)
->
None
:
)
->
None
:
"""
"""
pass a string `regex` to run `re.compile(r"regex")` on.
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
`fallback` defines the output returned if no matches for the regex are located.
"""
"""
super
().
__init__
(
**
kwargs
)
if
fallback
is
None
:
if
fallback
is
None
:
fallback
=
[
"invalid"
]
fallback
=
[
"invalid"
]
self
.
regex_pattern
=
regex_pattern
self
.
regex_pattern
=
regex_pattern
...
...
lm_eval/filters/selection.py
View file @
04e74420
...
@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
...
@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class
TakeKFilter
(
Filter
):
class
TakeKFilter
(
Filter
):
def
__init__
(
self
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
**
kwargs
)
->
None
:
self
.
k
=
kwargs
.
pop
(
"k"
)
self
.
k
=
kwargs
.
pop
(
"k"
)
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
apply
(
self
,
resps
,
docs
):
def
apply
(
self
,
resps
,
docs
):
...
...
lm_eval/filters/transformation.py
View file @
04e74420
...
@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
...
@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@
register_filter
(
"lowercase"
)
@
register_filter
(
"lowercase"
)
class
LowercaseFilter
(
Filter
):
class
LowercaseFilter
(
Filter
):
def
__init__
(
self
)
->
None
:
pass
def
apply
(
self
,
resps
,
docs
):
def
apply
(
self
,
resps
,
docs
):
def
filter_set
(
inst
):
def
filter_set
(
inst
):
return
[
resp
.
lower
()
for
resp
in
inst
]
return
[
resp
.
lower
()
for
resp
in
inst
]
...
@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
...
@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@
register_filter
(
"uppercase"
)
@
register_filter
(
"uppercase"
)
class
UppercaseFilter
(
Filter
):
class
UppercaseFilter
(
Filter
):
def
__init__
(
self
)
->
None
:
pass
def
apply
(
self
,
resps
,
docs
):
def
apply
(
self
,
resps
,
docs
):
def
filter_set
(
inst
):
def
filter_set
(
inst
):
return
[
resp
.
upper
()
for
resp
in
inst
]
return
[
resp
.
upper
()
for
resp
in
inst
]
...
@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
...
@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
@
register_filter
(
"map"
)
@
register_filter
(
"map"
)
class
MapFilter
(
Filter
):
class
MapFilter
(
Filter
):
def
__init__
(
self
,
mapping_dict
:
dict
=
None
,
default_value
=
None
)
->
None
:
def
__init__
(
self
,
mapping_dict
:
dict
=
None
,
default_value
=
None
)
->
None
:
super
().
__init__
()
"""
"""
Initializes the MapFilter with a given mapping dictionary and default value.
Initializes the MapFilter with a given mapping dictionary and default value.
...
@@ -60,9 +55,6 @@ class MapFilter(Filter):
...
@@ -60,9 +55,6 @@ class MapFilter(Filter):
@
register_filter
(
"format_span"
)
@
register_filter
(
"format_span"
)
class
SPANFilter
(
Filter
):
class
SPANFilter
(
Filter
):
def
__init__
(
self
)
->
None
:
pass
def
apply
(
self
,
resps
,
docs
):
def
apply
(
self
,
resps
,
docs
):
def
format_ner_text
(
text
):
def
format_ner_text
(
text
):
label_dict
=
{
label_dict
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment