Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
69d14fb3
Commit
69d14fb3
authored
Jul 21, 2025
by
Baber
Browse files
cleanup
parent
57b8c0b1
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
201 additions
and
259 deletions
+201
-259
lm_eval/api/filter.py
lm_eval/api/filter.py
+6
-4
lm_eval/api/registry.py
lm_eval/api/registry.py
+15
-12
lm_eval/api/task.py
lm_eval/api/task.py
+85
-111
lm_eval/config/metric.py
lm_eval/config/metric.py
+12
-9
lm_eval/config/task.py
lm_eval/config/task.py
+46
-47
lm_eval/config/template.py
lm_eval/config/template.py
+21
-19
lm_eval/config/utils.py
lm_eval/config/utils.py
+6
-6
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+10
-51
No files found.
lm_eval/api/filter.py
View file @
69d14fb3
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Iterable
,
List
,
Union
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.instance
import
Instance
...
@@ -20,7 +20,9 @@ class Filter(ABC):
...
@@ -20,7 +20,9 @@ class Filter(ABC):
"""
"""
@
abstractmethod
@
abstractmethod
def
apply
(
self
,
resps
:
Union
[
List
,
Iterable
],
docs
:
List
[
dict
])
->
Iterable
:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
"""
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...
@@ -40,9 +42,9 @@ class FilterEnsemble:
...
@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
"""
name
:
str
name
:
str
filters
:
L
ist
[
type
[
Filter
]]
filters
:
l
ist
[
type
[
Filter
]]
def
apply
(
self
,
instances
:
L
ist
[
Instance
])
->
None
:
def
apply
(
self
,
instances
:
l
ist
[
Instance
])
->
None
:
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
resps
,
docs
=
list
(
resps
),
list
(
docs
)
resps
,
docs
=
list
(
resps
),
list
(
docs
)
...
...
lm_eval/api/registry.py
View file @
69d14fb3
from
__future__
import
annotations
import
logging
import
logging
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -36,13 +38,14 @@ def register_model(*names):
...
@@ -36,13 +38,14 @@ def register_model(*names):
return
decorate
return
decorate
def
get_model
(
model_name
:
str
)
->
type
[
"
LM
"
]:
def
get_model
(
model_name
:
str
)
->
type
[
LM
]:
try
:
try
:
return
MODEL_REGISTRY
[
model_name
]
return
MODEL_REGISTRY
[
model_name
]
except
KeyError
:
except
KeyError
as
err
:
raise
ValueError
(
available_models
=
", "
.
join
(
MODEL_REGISTRY
.
keys
())
f
"Attempted to load model '
{
model_name
}
', but no model for this name found! Supported model names:
{
', '
.
join
(
MODEL_REGISTRY
.
keys
())
}
"
raise
KeyError
(
)
f
"Model '
{
model_name
}
' not found. Available models:
{
available_models
}
"
)
from
err
TASK_REGISTRY
=
{}
TASK_REGISTRY
=
{}
...
@@ -81,7 +84,7 @@ def register_group(name):
...
@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY
=
{}
OUTPUT_TYPE_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
AGGREGATION_REGISTRY
:
D
ict
[
str
,
Callable
[[],
D
ict
[
str
,
Callable
]]]
=
{}
AGGREGATION_REGISTRY
:
d
ict
[
str
,
Callable
[[],
d
ict
[
str
,
Callable
]]]
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
...
@@ -125,7 +128,7 @@ def register_metric(**args):
...
@@ -125,7 +128,7 @@ def register_metric(**args):
return
decorate
return
decorate
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Optional
[
Callable
]
:
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
[...,
Any
]
|
None
:
if
not
hf_evaluate_metric
:
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
if
name
in
METRIC_REGISTRY
:
return
METRIC_REGISTRY
[
name
]
return
METRIC_REGISTRY
[
name
]
...
@@ -157,21 +160,21 @@ def register_aggregation(name: str):
...
@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return
decorate
return
decorate
def
get_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
Dict
[
str
,
Callable
]]]
:
def
get_aggregation
(
name
:
str
)
->
Callable
[...,
Any
]
|
None
:
try
:
try
:
return
AGGREGATION_REGISTRY
[
name
]
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
def
get_metric_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
D
ict
[
str
,
Callable
]]
]
:
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
d
ict
[
str
,
Callable
]]
|
None
:
try
:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
def
is_higher_better
(
metric_name
:
str
)
->
Optional
[
bool
]
:
def
is_higher_better
(
metric_name
:
str
)
->
bool
|
None
:
try
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
except
KeyError
:
...
@@ -192,7 +195,7 @@ def register_filter(name: str):
...
@@ -192,7 +195,7 @@ def register_filter(name: str):
return
decorate
return
decorate
def
get_filter
(
filter_name
:
Union
[
str
,
Callable
]
)
->
Callable
:
def
get_filter
(
filter_name
:
str
|
Callable
)
->
Callable
:
try
:
try
:
return
FILTER_REGISTRY
[
filter_name
]
return
FILTER_REGISTRY
[
filter_name
]
except
KeyError
as
e
:
except
KeyError
as
e
:
...
...
lm_eval/api/task.py
View file @
69d14fb3
This diff is collapsed.
Click to expand it.
lm_eval/config/metric.py
View file @
69d14fb3
from
__future__
import
annotations
from
collections.abc
import
Callable
,
Mapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
List
,
Optional
from
typing
import
Any
@
dataclass
@
dataclass
...
@@ -8,9 +11,9 @@ class MetricConfig:
...
@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
"""Encapsulates information about a single metric."""
name
:
str
name
:
str
fn
:
Optional
[
Callable
]
=
None
fn
:
Callable
|
None
=
None
kwargs
:
Optional
[
dict
]
=
None
kwargs
:
Mapping
[
str
,
Any
]
|
None
=
None
aggregation_fn
:
Optional
[
Callable
]
=
None
aggregation_fn
:
Callable
|
None
=
None
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
hf_evaluate
:
bool
=
False
hf_evaluate
:
bool
=
False
is_elementwise
:
bool
=
True
is_elementwise
:
bool
=
True
...
@@ -20,7 +23,7 @@ class MetricConfig:
...
@@ -20,7 +23,7 @@ class MetricConfig:
return
self
.
name
return
self
.
name
@
cached_property
@
cached_property
def
aggregation
(
self
)
->
Callable
:
def
aggregation
(
self
)
->
Callable
[...,
Any
]
|
None
:
from
lm_eval.api.registry
import
get_aggregation
from
lm_eval.api.registry
import
get_aggregation
if
self
.
aggregation_fn
is
None
:
if
self
.
aggregation_fn
is
None
:
...
@@ -28,7 +31,7 @@ class MetricConfig:
...
@@ -28,7 +31,7 @@ class MetricConfig:
return
self
.
aggregation_fn
return
self
.
aggregation_fn
@
cached_property
@
cached_property
def
_higher_is_better
(
self
)
->
bool
:
def
_higher_is_better
(
self
)
->
bool
|
None
:
from
lm_eval.api.registry
import
is_higher_better
from
lm_eval.api.registry
import
is_higher_better
if
self
.
higher_is_better
is
None
:
if
self
.
higher_is_better
is
None
:
...
@@ -39,10 +42,10 @@ class MetricConfig:
...
@@ -39,10 +42,10 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments."""
"""Calculates the metric using the provided function and arguments."""
if
self
.
fn
is
None
:
if
self
.
fn
is
None
:
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
return
self
.
fn
(
*
args
,
**
{
**
self
.
kwargs
,
**
kwargs
})
return
self
.
fn
(
*
args
,
**
{
**
(
self
.
kwargs
or
{})
,
**
kwargs
})
def
compute_aggregation
(
self
,
values
:
List
[
Any
]
)
->
Any
:
def
compute_aggregation
(
self
,
*
args
,
**
kwargs
)
->
Any
:
"""Computes the aggregation of the metric values."""
"""Computes the aggregation of the metric values."""
if
self
.
aggregation_fn
is
None
:
if
self
.
aggregation_fn
is
None
:
raise
ValueError
(
f
"Aggregation function for
{
self
.
name
}
is not defined."
)
raise
ValueError
(
f
"Aggregation function for
{
self
.
name
}
is not defined."
)
return
self
.
aggregation_fn
(
value
s
)
return
self
.
aggregation_fn
(
*
args
,
**
kwarg
s
)
lm_eval/config/task.py
View file @
69d14fb3
from
__future__
import
annotations
import
logging
import
logging
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
,
field
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Iterable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
from
lm_eval.api.instance
import
OutputType
...
@@ -20,8 +23,8 @@ class RepeatConfig:
...
@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
"""Encapsulates information about a single repeat."""
repeats
:
int
=
1
repeats
:
int
=
1
metric_fn
:
Union
[
str
,
Callable
]
=
"pass@N"
metric_fn
:
str
|
Callable
=
"pass@N"
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
@
dataclass
@
dataclass
...
@@ -38,11 +41,11 @@ class FewshotConfig:
...
@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
# to keep in sync as it is runtime-modified
num_fewshot
:
Callable
[[],
int
]
num_fewshot
:
Callable
[[],
int
]
split
:
Optional
[
str
]
=
None
split
:
str
|
None
=
None
sampler
:
Union
[
str
,
Callable
]
=
"default"
sampler
:
str
|
Callable
=
"default"
samples
:
Union
[
Callable
[[],
list
[
dict
]]
,
list
[
dict
]
,
None
]
=
None
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
process_docs
:
Optional
[
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
]
=
None
process_docs
:
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
|
None
=
None
fewshot_indices
:
Optional
[
list
[
int
]
]
=
None
fewshot_indices
:
list
[
int
]
|
None
=
None
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
...
@@ -65,22 +68,20 @@ class FewshotConfig:
...
@@ -65,22 +68,20 @@ class FewshotConfig:
def
_get_raw_docs
(
def
_get_raw_docs
(
self
,
dataset
self
,
dataset
)
->
Union
[
list
[
dict
]
,
Callable
[[],
Iterable
[
dict
]]
,
None
]
:
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
]]
|
None
:
"""Get raw documents from configured source."""
"""Get raw documents from configured source."""
if
self
.
split
is
not
None
:
if
self
.
split
is
not
None
:
return
dataset
[
self
.
split
]
return
dataset
[
self
.
split
]
if
self
.
samples
is
not
None
:
if
self
.
samples
is
not
None
:
if
isinstance
(
self
.
samples
,
list
):
if
isinstance
(
self
.
samples
,
list
)
or
callable
(
self
.
samples
):
return
self
.
samples
elif
callable
(
self
.
samples
):
return
self
.
samples
return
self
.
samples
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"samples must be either a list of dicts or a callable returning a list"
"samples must be either a list of dicts or a callable returning a list"
)
)
def
get_docs
(
self
,
dataset
)
->
Optional
[
Iterable
[
dict
]
]
:
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
]
|
None
:
"""Get processed documents from configured source."""
"""Get processed documents from configured source."""
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
if
raw_docs
is
None
:
if
raw_docs
is
None
:
...
@@ -100,8 +101,8 @@ class FewshotConfig:
...
@@ -100,8 +101,8 @@ class FewshotConfig:
return
self
.
sampler
return
self
.
sampler
def
init_sampler
(
def
init_sampler
(
self
,
docs
:
list
[
dict
],
task
:
"
Task
"
,
rnd
=
None
,
fewshot_indices
=
None
self
,
docs
:
list
[
dict
],
task
:
Task
,
rnd
=
None
,
fewshot_indices
=
None
)
->
"
ContextSampler
"
:
)
->
ContextSampler
:
"""Initialize the sampler with the given documents and task."""
"""Initialize the sampler with the given documents and task."""
if
rnd
is
None
:
if
rnd
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -120,49 +121,49 @@ class FewshotConfig:
...
@@ -120,49 +121,49 @@ class FewshotConfig:
@
dataclass
@
dataclass
class
TaskConfig
(
dict
):
class
TaskConfig
(
dict
):
# task naming/registry
# task naming/registry
task
:
Optional
[
str
]
=
None
task
:
str
|
None
=
None
task_alias
:
Optional
[
str
]
=
None
task_alias
:
str
|
None
=
None
tag
:
Optional
[
Union
[
str
,
list
]]
=
None
tag
:
str
|
list
|
None
=
None
# HF dataset options.
# HF dataset options.
# which dataset to use,
# which dataset to use,
# and what splits for what purpose
# and what splits for what purpose
custom_dataset
:
Optional
[
Callable
]
=
None
custom_dataset
:
Callable
|
None
=
None
dataset_path
:
Optional
[
str
]
=
None
dataset_path
:
str
|
None
=
None
dataset_name
:
Optional
[
str
]
=
None
dataset_name
:
str
|
None
=
None
dataset_kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
dataset_kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
training_split
:
Optional
[
str
]
=
None
training_split
:
str
|
None
=
None
validation_split
:
Optional
[
str
]
=
None
validation_split
:
str
|
None
=
None
test_split
:
Optional
[
str
]
=
None
test_split
:
str
|
None
=
None
fewshot_split
:
Optional
[
str
]
=
(
fewshot_split
:
str
|
None
=
(
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
)
# formatting / prompting options.
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
# see docs/advanced_task_guide.md for more info
process_docs
:
Optional
[
Callable
]
=
None
process_docs
:
Callable
|
None
=
None
doc_to_text
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_text
:
Callable
|
str
|
None
=
None
doc_to_target
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_target
:
Callable
|
str
|
None
=
None
doc_to_image
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_image
:
Callable
|
str
|
None
=
None
doc_to_audio
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_audio
:
Callable
|
str
|
None
=
None
unsafe_code
:
bool
=
False
unsafe_code
:
bool
=
False
doc_to_choice
:
Optional
[
Union
[
Callable
,
str
,
dict
,
list
]]
=
None
doc_to_choice
:
Callable
|
str
|
dict
|
list
|
None
=
None
process_results
:
Optional
[
Union
[
Callable
,
str
]]
=
None
process_results
:
Callable
|
str
|
None
=
None
use_prompt
:
Optional
[
str
]
=
None
use_prompt
:
str
|
None
=
None
description
:
str
=
""
description
:
str
=
""
target_delimiter
:
str
=
" "
target_delimiter
:
str
=
" "
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_config
:
Optional
[
dict
]
=
None
fewshot_config
:
dict
|
None
=
None
# runtime configuration options
# runtime configuration options
num_fewshot
:
Optional
[
int
]
=
0
num_fewshot
:
int
|
None
=
0
generation_kwargs
:
Optional
[
dict
]
=
None
generation_kwargs
:
dict
|
None
=
None
# scoring options
# scoring options
metric_list
:
Optional
[
list
]
=
None
metric_list
:
list
|
None
=
None
output_type
:
OutputType
=
"generate_until"
output_type
:
OutputType
=
"generate_until"
repeats
:
int
=
1
repeats
:
int
=
1
filter_list
:
Optional
[
list
[
dict
]
]
=
None
filter_list
:
list
[
dict
]
|
None
=
None
should_decontaminate
:
bool
=
False
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
doc_to_decontamination_query
:
str
|
None
=
None
gen_prefix
:
Optional
[
str
]
=
None
gen_prefix
:
str
|
None
=
None
metadata
:
Optional
[
dict
]
=
field
(
metadata
:
dict
|
None
=
field
(
default_factory
=
dict
default_factory
=
dict
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
...
@@ -215,9 +216,7 @@ class TaskConfig(dict):
...
@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
)
)
def
_get_metric
(
def
_get_metric
(
self
,
metric_list
:
list
[
dict
]
|
None
=
None
)
->
list
[
MetricConfig
]:
self
,
metric_list
:
Optional
[
list
[
dict
]]
=
None
)
->
list
[
"MetricConfig"
]:
from
lm_eval.api.registry
import
(
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
...
@@ -314,7 +313,7 @@ class TaskConfig(dict):
...
@@ -314,7 +313,7 @@ class TaskConfig(dict):
return
metrics
return
metrics
@
property
@
property
def
get_filters
(
self
)
->
list
[
"
FilterConfig
"
]:
def
get_filters
(
self
)
->
list
[
FilterConfig
]:
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
if
not
self
.
filter_list
:
if
not
self
.
filter_list
:
...
@@ -354,7 +353,7 @@ class TaskConfig(dict):
...
@@ -354,7 +353,7 @@ class TaskConfig(dict):
return
x
return
x
@
classmethod
@
classmethod
def
from_yaml
(
cls
,
data
:
dict
)
->
"
TaskConfig
"
:
def
from_yaml
(
cls
,
data
:
dict
)
->
TaskConfig
:
"""Create a TaskConfig instance from a YAML-like dictionary."""
"""Create a TaskConfig instance from a YAML-like dictionary."""
return
cls
(
**
data
)
return
cls
(
**
data
)
...
...
lm_eval/config/template.py
View file @
69d14fb3
from
__future__
import
annotations
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -11,19 +13,19 @@ class TemplateConfig:
...
@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template."""
"""Encapsulates information about a template."""
template
:
str
template
:
str
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
description
:
str
description
:
str
context_prefix
:
str
context_prefix
:
str
prefix_delimiter
:
str
prefix_delimiter
:
str
context_delimiter
:
str
context_delimiter
:
str
answer_suffix
:
str
answer_suffix
:
str
target_delimiter
:
str
target_delimiter
:
str
choice_format
:
Optional
[
str
]
choice_format
:
str
|
None
choice_delimiter
:
Optional
[
str
]
choice_delimiter
:
str
|
None
fewshot_delimiter
:
str
fewshot_delimiter
:
str
metric_list
:
Optional
[
Union
[
list
[
str
]
,
list
[
"
MetricConfig
"
]]]
=
field
(
metric_list
:
list
[
str
]
|
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
)
...
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
...
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice.
Answer:` doc_to_choice(doc)` for each choice.
"""
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
=
"mcq"
template
=
"mcq"
context_prefix
:
str
=
"Question:"
context_prefix
:
str
=
"Question:"
prefix_delimiter
:
str
=
" "
prefix_delimiter
:
str
=
" "
context_delimiter
:
str
=
"
\n
"
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
"
\n
"
target_delimiter
:
str
=
"
\n
"
choice_format
:
Optional
[
str
]
=
"letters"
choice_format
:
str
|
None
=
"letters"
choice_delimiter
:
Optional
[
str
]
=
"
\n
"
choice_delimiter
:
str
|
None
=
"
\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
default_factory
=
lambda
:
[
"acc"
])
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
])
@
dataclass
@
dataclass
...
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
...
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>`
Answer:` <doc_to_target(doc)>`
"""
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
:
str
=
"cloze"
template
:
str
=
"cloze"
description
:
str
=
""
description
:
str
=
""
context_prefix
:
str
=
"Question:"
context_prefix
:
str
=
"Question:"
...
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
...
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter
:
str
=
"
\n
"
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
" "
target_delimiter
:
str
=
" "
choice_format
:
Optional
[
str
]
=
None
choice_format
:
str
|
None
=
None
choice_delimiter
:
Optional
[
str
]
=
None
choice_delimiter
:
str
|
None
=
None
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
)
lm_eval/config/utils.py
View file @
69d14fb3
from
__future__
import
annotations
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
Any
,
Callable
,
Union
from
typing
import
Any
,
Callable
def
serialize_callable
(
def
serialize_callable
(
value
:
Union
[
Callable
[...,
Any
]
,
str
]
,
keep_callable
=
False
value
:
Callable
[...,
Any
]
|
str
,
keep_callable
=
False
)
->
Union
[
Callable
[...,
Any
]
,
str
]
:
)
->
Callable
[...,
Any
]
|
str
:
"""Serializes a given function or string.
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
If 'keep_callable' is True, the original callable is returned.
...
@@ -20,9 +22,7 @@ def serialize_callable(
...
@@ -20,9 +22,7 @@ def serialize_callable(
return
str
(
value
)
return
str
(
value
)
def
maybe_serialize
(
def
maybe_serialize
(
val
:
Callable
|
Any
,
keep_callable
=
False
)
->
Callable
|
Any
:
val
:
Union
[
Callable
,
Any
],
keep_callable
=
False
)
->
Union
[
Callable
,
Any
]:
"""Conditionally serializes a value if it is callable."""
"""Conditionally serializes a value if it is callable."""
return
(
return
(
...
...
lm_eval/filters/extraction.py
View file @
69d14fb3
import
re
import
re
import
sys
import
sys
import
unicodedata
import
unicodedata
from
collections.abc
import
Iterable
from
lm_eval.api.filter
import
Filter
from
lm_eval.api.filter
import
Filter
from
lm_eval.api.registry
import
register_filter
from
lm_eval.api.registry
import
register_filter
...
@@ -32,7 +33,9 @@ class RegexFilter(Filter):
...
@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self
.
group_select
=
group_select
self
.
group_select
=
group_select
self
.
fallback
=
fallback
self
.
fallback
=
fallback
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# so we process each of these (same input/target response sets)
...
@@ -59,59 +62,13 @@ class RegexFilter(Filter):
...
@@ -59,59 +62,13 @@ class RegexFilter(Filter):
return
filtered_resps
return
filtered_resps
@
register_filter
(
"regex_pos"
)
class
POSFilter
(
Filter
):
""" """
def
__init__
(
self
,
regex_pattern
:
str
=
r
"\['(.*?)'\]"
,
group_select
=
0
,
fallback
=
None
,
**
kwargs
,
)
->
None
:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super
().
__init__
(
**
kwargs
)
if
fallback
is
None
:
fallback
=
[
"invalid"
]
self
.
regex_pattern
=
regex_pattern
self
.
regex
=
re
.
compile
(
regex_pattern
)
self
.
group_select
=
group_select
self
.
fallback
=
fallback
def
apply
(
self
,
resps
,
docs
):
def
extract_tagged_tokens
(
text
):
# Extract tagged tokens list from text input using regex
tokens
=
re
.
findall
(
r
"\('([^']*)', '([^']*)'\)"
,
text
)
return
[(
token
,
pos
)
for
token
,
pos
in
tokens
]
def
extract_pos_tags
(
result
):
pos_tags
=
[]
if
isinstance
(
result
,
str
):
result
=
extract_tagged_tokens
(
result
)
pos_tags
.
extend
(
pos
for
_
,
pos
in
result
)
return
pos_tags
if
pos_tags
else
self
.
fallback
def
filter_set
(
inst
):
filtered
=
[]
for
resp
in
inst
:
match
=
extract_pos_tags
(
resp
)
filtered
.
append
(
match
)
return
filtered
filtered_resps
=
map
(
lambda
x
:
filter_set
(
x
),
resps
)
return
filtered_resps
@
register_filter
(
"remove_whitespace"
)
@
register_filter
(
"remove_whitespace"
)
class
WhitespaceFilter
(
Filter
):
class
WhitespaceFilter
(
Filter
):
"""Filters out leading whitespace from responses."""
"""Filters out leading whitespace from responses."""
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
def
filter_set
(
inst
):
def
filter_set
(
inst
):
filtered_resp
=
[]
filtered_resp
=
[]
for
resp
in
inst
:
for
resp
in
inst
:
...
@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
...
@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self
.
ignore_punctuation
=
ignore_punctuation
self
.
ignore_punctuation
=
ignore_punctuation
self
.
regexes_to_ignore
=
regexes_to_ignore
self
.
regexes_to_ignore
=
regexes_to_ignore
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# so we process each of these (same input/target response sets)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment