Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
69d14fb3
Commit
69d14fb3
authored
Jul 21, 2025
by
Baber
Browse files
cleanup
parent
57b8c0b1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
201 additions
and
259 deletions
+201
-259
lm_eval/api/filter.py
lm_eval/api/filter.py
+6
-4
lm_eval/api/registry.py
lm_eval/api/registry.py
+15
-12
lm_eval/api/task.py
lm_eval/api/task.py
+85
-111
lm_eval/config/metric.py
lm_eval/config/metric.py
+12
-9
lm_eval/config/task.py
lm_eval/config/task.py
+46
-47
lm_eval/config/template.py
lm_eval/config/template.py
+21
-19
lm_eval/config/utils.py
lm_eval/config/utils.py
+6
-6
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+10
-51
No files found.
lm_eval/api/filter.py
View file @
69d14fb3
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
typing
import
Iterable
,
List
,
Union
from
lm_eval.api.instance
import
Instance
...
...
@@ -20,7 +20,9 @@ class Filter(ABC):
"""
@
abstractmethod
def
apply
(
self
,
resps
:
Union
[
List
,
Iterable
],
docs
:
List
[
dict
])
->
Iterable
:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...
...
@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
name
:
str
filters
:
L
ist
[
type
[
Filter
]]
filters
:
l
ist
[
type
[
Filter
]]
def
apply
(
self
,
instances
:
L
ist
[
Instance
])
->
None
:
def
apply
(
self
,
instances
:
l
ist
[
Instance
])
->
None
:
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
resps
,
docs
=
list
(
resps
),
list
(
docs
)
...
...
lm_eval/api/registry.py
View file @
69d14fb3
from
__future__
import
annotations
import
logging
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
if
TYPE_CHECKING
:
...
...
@@ -36,13 +38,14 @@ def register_model(*names):
return
decorate
def
get_model
(
model_name
:
str
)
->
type
[
"
LM
"
]:
def
get_model
(
model_name
:
str
)
->
type
[
LM
]:
try
:
return
MODEL_REGISTRY
[
model_name
]
except
KeyError
:
raise
ValueError
(
f
"Attempted to load model '
{
model_name
}
', but no model for this name found! Supported model names:
{
', '
.
join
(
MODEL_REGISTRY
.
keys
())
}
"
)
except
KeyError
as
err
:
available_models
=
", "
.
join
(
MODEL_REGISTRY
.
keys
())
raise
KeyError
(
f
"Model '
{
model_name
}
' not found. Available models:
{
available_models
}
"
)
from
err
TASK_REGISTRY
=
{}
...
...
@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY
=
{}
METRIC_REGISTRY
=
{}
METRIC_AGGREGATION_REGISTRY
=
{}
AGGREGATION_REGISTRY
:
D
ict
[
str
,
Callable
[[],
D
ict
[
str
,
Callable
]]]
=
{}
AGGREGATION_REGISTRY
:
d
ict
[
str
,
Callable
[[],
d
ict
[
str
,
Callable
]]]
=
{}
HIGHER_IS_BETTER_REGISTRY
=
{}
FILTER_REGISTRY
=
{}
...
...
@@ -125,7 +128,7 @@ def register_metric(**args):
return
decorate
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Optional
[
Callable
]
:
def
get_metric
(
name
:
str
,
hf_evaluate_metric
=
False
)
->
Callable
[...,
Any
]
|
None
:
if
not
hf_evaluate_metric
:
if
name
in
METRIC_REGISTRY
:
return
METRIC_REGISTRY
[
name
]
...
...
@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return
decorate
def
get_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
Dict
[
str
,
Callable
]]]
:
def
get_aggregation
(
name
:
str
)
->
Callable
[...,
Any
]
|
None
:
try
:
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
not a registered aggregation metric!"
)
def
get_metric_aggregation
(
name
:
str
)
->
Optional
[
Callable
[[],
D
ict
[
str
,
Callable
]]
]
:
def
get_metric_aggregation
(
name
:
str
)
->
Callable
[[],
d
ict
[
str
,
Callable
]]
|
None
:
try
:
return
METRIC_AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
eval_logger
.
warning
(
f
"
{
name
}
metric is not assigned a default aggregation!"
)
def
is_higher_better
(
metric_name
:
str
)
->
Optional
[
bool
]
:
def
is_higher_better
(
metric_name
:
str
)
->
bool
|
None
:
try
:
return
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
except
KeyError
:
...
...
@@ -192,7 +195,7 @@ def register_filter(name: str):
return
decorate
def
get_filter
(
filter_name
:
Union
[
str
,
Callable
]
)
->
Callable
:
def
get_filter
(
filter_name
:
str
|
Callable
)
->
Callable
:
try
:
return
FILTER_REGISTRY
[
filter_name
]
except
KeyError
as
e
:
...
...
lm_eval/api/task.py
View file @
69d14fb3
from
__future__
import
annotations
import
abc
import
ast
import
logging
...
...
@@ -8,15 +10,7 @@ from copy import deepcopy
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
Iterator
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
Union
,
)
import
datasets
...
...
@@ -57,23 +51,23 @@ class Task(abc.ABC):
{"question": ..., question, answer)
"""
VERSION
:
Optional
[
Union
[
int
,
str
]]
=
None
VERSION
:
int
|
str
|
None
=
None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH
:
Optional
[
str
]
=
None
DATASET_PATH
:
str
|
None
=
None
# The name of a subset within `DATASET_PATH`.
DATASET_NAME
:
Optional
[
str
]
=
None
DATASET_NAME
:
str
|
None
=
None
OUTPUT_TYPE
:
Optional
[
OutputType
]
=
None
OUTPUT_TYPE
:
OutputType
|
None
=
None
def
__init__
(
self
,
data_dir
:
Optional
[
str
]
=
None
,
cache_dir
:
Optional
[
str
]
=
None
,
download_mode
:
Optional
[
datasets
.
DownloadMode
]
=
None
,
config
:
Optional
[
Mapping
]
=
None
,
# Union[dict, TaskConfig]
data_dir
:
str
|
None
=
None
,
cache_dir
:
str
|
None
=
None
,
download_mode
:
datasets
.
DownloadMode
|
None
=
None
,
config
:
Mapping
|
None
=
None
,
# Union[dict, TaskConfig]
)
->
None
:
"""
:param data_dir: str
...
...
@@ -97,21 +91,21 @@ class Task(abc.ABC):
Fresh download and fresh dataset.
"""
self
.
download
(
data_dir
,
cache_dir
,
download_mode
)
self
.
_training_docs
:
Optional
[
list
]
=
None
self
.
_fewshot_docs
:
Optional
[
list
]
=
None
self
.
_instances
:
Optional
[
L
ist
[
Instance
]
]
=
None
self
.
_training_docs
:
list
|
None
=
None
self
.
_fewshot_docs
:
list
|
None
=
None
self
.
_instances
:
l
ist
[
Instance
]
|
None
=
None
self
.
_config
:
TaskConfig
=
TaskConfig
.
from_yaml
({
**
config
})
self
.
_filters
=
[
build_filter_ensemble
(
"none"
,
[(
"take_first"
,
None
)])]
self
.
fewshot_rnd
:
Optional
[
random
.
Random
]
=
(
self
.
fewshot_rnd
:
random
.
Random
|
None
=
(
None
# purposely induce errors in case of improper usage
)
def
download
(
self
,
data_dir
:
Optional
[
str
]
=
None
,
cache_dir
:
Optional
[
str
]
=
None
,
data_dir
:
str
|
None
=
None
,
cache_dir
:
str
|
None
=
None
,
download_mode
=
None
,
)
->
None
:
"""Downloads and returns the task dataset.
...
...
@@ -238,7 +232,7 @@ class Task(abc.ABC):
pass
@
abc
.
abstractmethod
def
doc_to_target
(
self
,
doc
:
dict
)
->
Union
[
str
,
int
]
:
def
doc_to_target
(
self
,
doc
:
dict
)
->
str
|
int
:
pass
# not an abstractmethod because not every language-only task has to implement this
...
...
@@ -254,16 +248,16 @@ class Task(abc.ABC):
def
build_all_requests
(
self
,
*
,
limit
:
Union
[
int
,
None
]
=
None
,
samples
:
Optional
[
L
ist
[
int
]
]
=
None
,
limit
:
int
|
None
=
None
,
samples
:
l
ist
[
int
]
|
None
=
None
,
rank
:
int
=
0
,
world_size
:
int
=
1
,
cache_requests
:
bool
=
False
,
rewrite_requests_cache
:
bool
=
False
,
system_instruction
:
Optional
[
str
]
=
None
,
system_instruction
:
str
|
None
=
None
,
apply_chat_template
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
chat_template
:
Optional
[
Callable
]
=
None
,
chat_template
:
Callable
|
None
=
None
,
tokenizer_name
:
str
=
""
,
)
->
None
:
"""Build a set of Instances for a task, and store them in task.instances"""
...
...
@@ -365,7 +359,7 @@ class Task(abc.ABC):
save_to_cache
(
file_name
=
cache_key
,
obj
=
instances
)
@
abc
.
abstractmethod
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
Union
[
list
[
dict
]
,
str
]
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
list
[
dict
]
|
str
,
**
kwargs
):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
...
...
@@ -405,7 +399,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
pass
return
True
@
deprecated
(
"not used anymore"
)
def
higher_is_better
(
self
):
...
...
@@ -414,7 +408,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
pass
return
True
def
get_config
(
self
,
key
:
str
)
->
Any
:
return
getattr
(
self
.
_config
,
key
,
None
)
...
...
@@ -488,13 +482,15 @@ class Task(abc.ABC):
example
=
self
.
doc_to_text
(
doc
)
return
description
+
labeled_examples
+
example
def
apply_filters
(
self
)
->
Optional
[
L
ist
[
Instance
]
]
:
def
apply_filters
(
self
)
->
l
ist
[
Instance
]
|
None
:
"""Iterates over FilterEnsembles and applies them to instances"""
if
hasattr
(
self
,
"_filters"
):
if
hasattr
(
self
,
"_filters"
)
and
self
.
_instances
:
for
f
in
self
.
_filters
:
f
.
apply
(
self
.
_instances
)
else
:
eval_logger
.
warning
(
"No filter defined, passing through instances"
)
eval_logger
.
warning
(
"No filter defined or no instances, passing through instances"
)
return
self
.
_instances
def
dump_config
(
self
)
->
dict
:
...
...
@@ -505,9 +501,6 @@ class Task(abc.ABC):
def
set_config
(
self
,
key
:
str
,
value
:
Any
,
update
:
bool
=
False
)
->
None
:
"""Set or update the configuration for a given key."""
if
key
is
None
:
raise
ValueError
(
"Key must be provided."
)
if
update
:
current_value
=
getattr
(
self
.
_config
,
key
,
{})
if
not
isinstance
(
current_value
,
dict
):
...
...
@@ -533,13 +526,13 @@ class Task(abc.ABC):
setattr
(
self
.
_config
,
"metric_list"
,
[
MetricConfig
(
name
=
metric_name
)])
setattr
(
self
.
_config
,
"process_results"
,
lambda
*
args
:
{
"bypass"
:
0
})
def
set_fewshot_seed
(
self
,
seed
:
Optional
[
int
]
=
None
)
->
None
:
def
set_fewshot_seed
(
self
,
seed
:
int
|
None
=
None
)
->
None
:
self
.
fewshot_rnd
=
random
.
Random
(
seed
)
if
hasattr
(
self
,
"sampler"
):
self
.
sampler
.
rnd
=
self
.
fewshot_rnd
@
property
def
eval_docs
(
self
)
->
Union
[
datasets
.
Dataset
,
Iterable
[
dict
]
]
:
def
eval_docs
(
self
)
->
datasets
.
Dataset
|
Iterable
[
dict
]:
if
self
.
has_test_docs
():
return
self
.
test_docs
()
elif
self
.
has_validation_docs
():
...
...
@@ -553,13 +546,13 @@ class Task(abc.ABC):
self
,
*
,
rank
:
int
=
0
,
limit
:
Union
[
int
,
None
]
=
None
,
limit
:
int
|
None
=
None
,
world_size
:
int
=
1
,
samples
:
Optional
[
L
ist
[
int
]
]
=
None
,
)
->
Iterator
[
T
uple
[
int
,
Any
]]:
samples
:
l
ist
[
int
]
|
None
=
None
,
)
->
Iterator
[
t
uple
[
int
,
Any
]]:
if
samples
:
n
=
len
(
self
.
eval_docs
)
assert
all
(
[
e
<
n
for
e
in
samples
]
),
(
assert
all
(
e
<
n
for
e
in
samples
),
(
f
"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k=
{
n
}
."
)
eval_logger
.
info
(
...
...
@@ -592,7 +585,7 @@ class ConfigurableTask(Task):
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
config
:
Optional
[
dict
]
=
None
,
config
:
dict
|
None
=
None
,
)
->
None
:
# Get pre-configured attributes
self
.
_config
=
self
.
CONFIG
...
...
@@ -610,9 +603,8 @@ class ConfigurableTask(Task):
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
if
isinstance
(
self
.
config
.
metadata
,
dict
):
if
"version"
in
self
.
config
.
metadata
:
self
.
VERSION
=
self
.
config
.
metadata
[
"version"
]
if
isinstance
(
self
.
config
.
metadata
,
dict
)
and
"version"
in
self
.
config
.
metadata
:
self
.
VERSION
=
self
.
config
.
metadata
[
"version"
]
if
self
.
config
.
output_type
is
not
None
:
if
self
.
config
.
output_type
not
in
ALL_OUTPUT_TYPES
:
...
...
@@ -698,18 +690,13 @@ class ConfigurableTask(Task):
else
:
test_target
=
str
(
test_target
)
if
test_choice
is
not
None
:
check_choices
=
test_choice
else
:
check_choices
=
[
test_target
]
check_choices
=
test_choice
if
test_choice
is
not
None
else
[
test_target
]
if
self
.
config
.
doc_to_choice
is
not
None
:
for
choice
in
check_choices
:
choice_has_whitespace
=
True
if
choice
[
0
].
isspace
()
else
False
choice_has_whitespace
=
choice
[
0
].
isspace
()
delimiter_has_whitespace
=
(
True
if
self
.
config
.
target_delimiter
.
rstrip
()
self
.
config
.
target_delimiter
.
rstrip
()
!=
self
.
config
.
target_delimiter
else
False
)
if
delimiter_has_whitespace
and
choice_has_whitespace
:
...
...
@@ -722,7 +709,7 @@ class ConfigurableTask(Task):
)
def
download
(
self
,
dataset_kwargs
:
Optional
[
D
ict
[
str
,
Any
]
]
=
None
,
**
kwargs
self
,
dataset_kwargs
:
d
ict
[
str
,
Any
]
|
None
=
None
,
**
kwargs
)
->
None
:
from
packaging.version
import
parse
as
vparse
...
...
@@ -746,24 +733,15 @@ class ConfigurableTask(Task):
)
def
has_training_docs
(
self
)
->
bool
:
if
self
.
config
.
training_split
is
not
None
:
return
True
else
:
return
False
return
self
.
config
.
training_split
is
not
None
def
has_validation_docs
(
self
)
->
bool
:
if
self
.
config
.
validation_split
is
not
None
:
return
True
else
:
return
False
return
self
.
config
.
validation_split
is
not
None
def
has_test_docs
(
self
)
->
bool
:
if
self
.
config
.
test_split
is
not
None
:
return
True
else
:
return
False
return
self
.
config
.
test_split
is
not
None
def
training_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
def
training_docs
(
self
)
->
datasets
.
Dataset
|
None
:
if
self
.
has_training_docs
():
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
...
...
@@ -771,7 +749,7 @@ class ConfigurableTask(Task):
)
return
self
.
dataset
[
self
.
config
.
training_split
]
def
validation_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
def
validation_docs
(
self
)
->
datasets
.
Dataset
|
None
:
if
self
.
has_validation_docs
():
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
...
...
@@ -779,7 +757,7 @@ class ConfigurableTask(Task):
)
return
self
.
dataset
[
self
.
config
.
validation_split
]
def
test_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
def
test_docs
(
self
)
->
datasets
.
Dataset
|
None
:
if
self
.
has_test_docs
():
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
...
...
@@ -792,22 +770,25 @@ class ConfigurableTask(Task):
return
docs
# Fallback to parent implementation
if
_num_fewshot
:
=
getattr
(
self
.
config
,
"num_fewshot"
):
if
isinstance
(
_num_fewshot
,
int
)
and
_num_fewshot
>
0
:
eval_logger
.
warning
(
f
"[Task:
{
self
.
config
.
task
}
] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
if
(
(
_num_fewshot
:
=
self
.
config
.
num_fewshot
)
and
isinstance
(
_num_fewshot
,
int
)
and
_num_fewshot
>
0
):
eval_logger
.
warning
(
f
"[Task:
{
self
.
config
.
task
}
] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
return
super
().
fewshot_docs
()
@
staticmethod
def
append_target_question
(
labeled_examples
:
L
ist
[
D
ict
[
str
,
str
]],
labeled_examples
:
l
ist
[
d
ict
[
str
,
str
]],
question
:
str
,
fewshot_as_multiturn
:
bool
=
False
,
gen_prefix
:
Optional
[
str
]
=
None
,
gen_prefix
:
str
|
None
=
None
,
)
->
None
:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
...
...
@@ -831,12 +812,12 @@ class ConfigurableTask(Task):
self
,
doc
:
dict
,
num_fewshot
:
int
,
system_instruction
:
Optional
[
str
]
=
None
,
system_instruction
:
str
|
None
=
None
,
apply_chat_template
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
chat_template
:
Optional
[
Callable
]
=
None
,
gen_prefix
:
Optional
[
str
]
=
None
,
)
->
Union
[
str
,
L
ist
[
str
]
,
None
]
:
chat_template
:
Callable
|
None
=
None
,
gen_prefix
:
str
|
None
=
None
,
)
->
str
|
l
ist
[
str
]
|
None
:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
...
...
@@ -857,10 +838,7 @@ class ConfigurableTask(Task):
:returns: str
The fewshot context.
"""
if
apply_chat_template
:
labeled_examples
=
[]
else
:
labeled_examples
=
""
labeled_examples
=
[]
if
apply_chat_template
else
""
# get task description
if
description
:
=
self
.
config
.
description
:
...
...
@@ -930,7 +908,7 @@ class ConfigurableTask(Task):
labeled_examples_list
.
append
(
chat_template
(
chat
,
add_generation_prompt
=
False
if
gen_prefix
else
True
,
add_generation_prompt
=
not
gen_prefix
,
)
)
return
labeled_examples_list
...
...
@@ -954,7 +932,7 @@ class ConfigurableTask(Task):
# return lm.apply_chat_template(labeled_examples)
return
chat_template
(
labeled_examples
,
add_generation_prompt
=
False
if
gen_prefix
else
True
,
add_generation_prompt
=
not
gen_prefix
,
)
else
:
prefix
=
(
...
...
@@ -975,7 +953,7 @@ class ConfigurableTask(Task):
else
:
return
labeled_examples
+
str
(
example
)
+
prefix
def
apply_filters
(
self
)
->
Optional
[
L
ist
[
Instance
]
]
:
def
apply_filters
(
self
)
->
l
ist
[
Instance
]
|
None
:
"""Iterates over FilterEnsembles and applies them to instances"""
if
hasattr
(
self
,
"_filters"
):
for
f
in
self
.
_filters
:
...
...
@@ -1015,9 +993,7 @@ class ConfigurableTask(Task):
"""
return
doc
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
Union
[
int
,
str
,
Callable
,
None
]
=
None
):
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
|
None
=
None
):
# if self.prompt is not None:
# doc_to_text = self.prompt
if
doc_to_text
is
not
None
:
...
...
@@ -1053,9 +1029,7 @@ class ConfigurableTask(Task):
print
(
type
(
doc_to_text
))
raise
TypeError
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
Union
[
int
,
str
,
list
[
int
]]:
def
doc_to_target
(
self
,
doc
:
dict
,
doc_to_target
=
None
)
->
int
|
str
|
list
[
int
]:
# if self.prompt is not None:
# doc_to_target = self.prompt
if
doc_to_target
is
not
None
:
...
...
@@ -1104,8 +1078,8 @@ class ConfigurableTask(Task):
def
doc_to_choice
(
self
,
doc
:
dict
,
doc_to_choice
:
Union
[
str
,
list
,
dict
,
Callable
[...,
list
[
str
]]
,
None
]
=
None
,
)
->
L
ist
[
str
]:
doc_to_choice
:
str
|
list
|
dict
|
Callable
[...,
list
[
str
]]
|
None
=
None
,
)
->
l
ist
[
str
]:
# if self.prompt is not None:
# doc_to_choice = self.prompt
if
doc_to_choice
is
not
None
:
...
...
@@ -1132,7 +1106,7 @@ class ConfigurableTask(Task):
else
:
raise
TypeError
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
,
None
]
:
def
doc_to_image
(
self
,
doc
:
dict
,
doc_to_image
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_image
is
not
None
:
doc_to_image
=
doc_to_image
elif
self
.
config
.
doc_to_image
is
not
None
:
...
...
@@ -1155,7 +1129,7 @@ class ConfigurableTask(Task):
else
:
return
None
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
Union
[
int
,
str
,
list
,
None
]
:
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
int
|
str
|
list
|
None
:
if
doc_to_audio
is
not
None
:
doc_to_audio
=
doc_to_audio
elif
self
.
config
.
doc_to_audio
is
not
None
:
...
...
@@ -1178,7 +1152,7 @@ class ConfigurableTask(Task):
else
:
return
None
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
Optional
[
str
]
:
def
doc_to_prefix
(
self
,
doc
:
dict
)
->
str
|
None
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
gen_prefix
in
self
.
features
:
return
doc
[
gen_prefix
]
...
...
@@ -1188,7 +1162,7 @@ class ConfigurableTask(Task):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
Union
[
L
ist
[
Instance
]
,
Instance
]
:
)
->
l
ist
[
Instance
]
|
Instance
:
apply_chat_template
=
kwargs
.
pop
(
"apply_chat_template"
,
False
)
chat_template
:
Callable
|
None
=
kwargs
.
pop
(
"chat_template"
,
None
)
...
...
@@ -1324,7 +1298,7 @@ class ConfigurableTask(Task):
elif
self
.
OUTPUT_TYPE
==
"multiple_choice"
:
lls
,
is_greedy
=
zip
(
*
results
)
# retrieve choices in
L
ist[str] form, to compute choice lengths, etc.
# retrieve choices in
l
ist[str] form, to compute choice lengths, etc.
choices
=
self
.
doc_to_choice
(
doc
)
completion_len
=
np
.
array
([
float
(
len
(
i
))
for
i
in
choices
])
...
...
@@ -1371,7 +1345,7 @@ class ConfigurableTask(Task):
if
self
.
multiple_target
:
acc
=
1.0
if
pred
in
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
in
gold
else
0.0
exact_match
=
int
(
any
(
[
is_greedy
[
i
]
if
i
!=
-
100
else
0
for
i
in
gold
]
))
exact_match
=
int
(
any
(
is_greedy
[
i
]
if
i
!=
-
100
else
0
for
i
in
gold
))
else
:
acc
=
1.0
if
pred
==
gold
else
0.0
acc_norm
=
1.0
if
pred_norm
==
gold
else
0.0
...
...
@@ -1413,7 +1387,7 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number.
choices
=
self
.
doc_to_choice
(
doc
)
gold
=
choices
[
gold
]
for
metric
in
self
.
_metric_fn_list
.
keys
()
:
for
metric
in
self
.
_metric_fn_list
:
try
:
result_score
=
self
.
_metric_fn_list
[
metric
](
references
=
[
gold
]
if
not
isinstance
(
gold
,
list
)
else
gold
,
...
...
@@ -1447,7 +1421,7 @@ class ConfigurableTask(Task):
return
getattr
(
self
.
_config
,
key
,
None
)
@
property
def
task_name
(
self
)
->
Optional
[
str
]
:
def
task_name
(
self
)
->
str
|
None
:
return
getattr
(
self
.
config
,
"task"
,
None
)
def
__repr__
(
self
):
...
...
@@ -1465,7 +1439,7 @@ class MultipleChoiceTask(Task):
def
doc_to_target
(
self
,
doc
:
dict
)
->
str
:
return
" "
+
doc
[
"choices"
][
doc
[
"gold"
]]
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
L
ist
[
Instance
]:
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
,
**
kwargs
)
->
l
ist
[
Instance
]:
# TODO: add mutual info here?
return
[
Instance
(
...
...
@@ -1478,7 +1452,7 @@ class MultipleChoiceTask(Task):
for
i
,
choice
in
enumerate
(
doc
[
"choices"
])
]
def
process_results
(
self
,
doc
:
dict
,
results
:
Iterable
[
T
uple
[
float
,
bool
]])
->
dict
:
def
process_results
(
self
,
doc
:
dict
,
results
:
Iterable
[
t
uple
[
float
,
bool
]])
->
dict
:
results
=
[
res
[
0
]
for
res
in
results
]
# only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
...
...
@@ -1512,7 +1486,7 @@ class PerplexityTask(Task):
def
has_training_docs
(
self
)
->
bool
:
return
False
def
fewshot_examples
(
self
,
k
:
int
,
rnd
)
->
L
ist
:
def
fewshot_examples
(
self
,
k
:
int
,
rnd
)
->
l
ist
:
if
k
!=
0
:
raise
ValueError
(
"The number of fewshot examples must be 0 for perplexity tasks."
...
...
@@ -1543,7 +1517,7 @@ class PerplexityTask(Task):
def
doc_to_target
(
self
,
doc
):
return
doc
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
Optional
[
str
]
,
**
kwargs
):
def
construct_requests
(
self
,
doc
:
dict
,
ctx
:
str
|
None
,
**
kwargs
):
if
bool
(
ctx
):
raise
ValueError
...
...
@@ -1555,7 +1529,7 @@ class PerplexityTask(Task):
**
kwargs
,
)
def
process_results
(
self
,
doc
:
dict
,
results
:
T
uple
[
float
])
->
dict
:
def
process_results
(
self
,
doc
:
dict
,
results
:
t
uple
[
float
])
->
dict
:
(
loglikelihood
,)
=
results
words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
bytes_
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
...
...
lm_eval/config/metric.py
View file @
69d14fb3
from
__future__
import
annotations
from
collections.abc
import
Callable
,
Mapping
from
dataclasses
import
dataclass
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
List
,
Optional
from
typing
import
Any
@
dataclass
...
...
@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
name
:
str
fn
:
Optional
[
Callable
]
=
None
kwargs
:
Optional
[
dict
]
=
None
aggregation_fn
:
Optional
[
Callable
]
=
None
fn
:
Callable
|
None
=
None
kwargs
:
Mapping
[
str
,
Any
]
|
None
=
None
aggregation_fn
:
Callable
|
None
=
None
higher_is_better
:
bool
=
True
hf_evaluate
:
bool
=
False
is_elementwise
:
bool
=
True
...
...
@@ -20,7 +23,7 @@ class MetricConfig:
return
self
.
name
@
cached_property
def
aggregation
(
self
)
->
Callable
:
def
aggregation
(
self
)
->
Callable
[...,
Any
]
|
None
:
from
lm_eval.api.registry
import
get_aggregation
if
self
.
aggregation_fn
is
None
:
...
...
@@ -28,7 +31,7 @@ class MetricConfig:
return
self
.
aggregation_fn
@
cached_property
def
_higher_is_better
(
self
)
->
bool
:
def
_higher_is_better
(
self
)
->
bool
|
None
:
from
lm_eval.api.registry
import
is_higher_better
if
self
.
higher_is_better
is
None
:
...
...
@@ -39,10 +42,10 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments."""
if
self
.
fn
is
None
:
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
return
self
.
fn
(
*
args
,
**
{
**
self
.
kwargs
,
**
kwargs
})
return
self
.
fn
(
*
args
,
**
{
**
(
self
.
kwargs
or
{})
,
**
kwargs
})
def
compute_aggregation
(
self
,
values
:
List
[
Any
]
)
->
Any
:
def
compute_aggregation
(
self
,
*
args
,
**
kwargs
)
->
Any
:
"""Computes the aggregation of the metric values."""
if
self
.
aggregation_fn
is
None
:
raise
ValueError
(
f
"Aggregation function for
{
self
.
name
}
is not defined."
)
return
self
.
aggregation_fn
(
value
s
)
return
self
.
aggregation_fn
(
*
args
,
**
kwarg
s
)
lm_eval/config/task.py
View file @
69d14fb3
from
__future__
import
annotations
import
logging
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Iterable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
...
...
@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats
:
int
=
1
metric_fn
:
Union
[
str
,
Callable
]
=
"pass@N"
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
metric_fn
:
str
|
Callable
=
"pass@N"
kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
@
dataclass
...
...
@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot
:
Callable
[[],
int
]
split
:
Optional
[
str
]
=
None
sampler
:
Union
[
str
,
Callable
]
=
"default"
samples
:
Union
[
Callable
[[],
list
[
dict
]]
,
list
[
dict
]
,
None
]
=
None
process_docs
:
Optional
[
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
]
=
None
fewshot_indices
:
Optional
[
list
[
int
]
]
=
None
split
:
str
|
None
=
None
sampler
:
str
|
Callable
=
"default"
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
process_docs
:
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
|
None
=
None
fewshot_indices
:
list
[
int
]
|
None
=
None
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
def
__post_init__
(
self
)
->
None
:
...
...
@@ -65,22 +68,20 @@ class FewshotConfig:
def
_get_raw_docs
(
self
,
dataset
)
->
Union
[
list
[
dict
]
,
Callable
[[],
Iterable
[
dict
]]
,
None
]
:
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
]]
|
None
:
"""Get raw documents from configured source."""
if
self
.
split
is
not
None
:
return
dataset
[
self
.
split
]
if
self
.
samples
is
not
None
:
if
isinstance
(
self
.
samples
,
list
):
return
self
.
samples
elif
callable
(
self
.
samples
):
if
isinstance
(
self
.
samples
,
list
)
or
callable
(
self
.
samples
):
return
self
.
samples
else
:
raise
TypeError
(
"samples must be either a list of dicts or a callable returning a list"
)
def
get_docs
(
self
,
dataset
)
->
Optional
[
Iterable
[
dict
]
]
:
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
]
|
None
:
"""Get processed documents from configured source."""
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
if
raw_docs
is
None
:
...
...
@@ -100,8 +101,8 @@ class FewshotConfig:
return
self
.
sampler
def
init_sampler
(
self
,
docs
:
list
[
dict
],
task
:
"
Task
"
,
rnd
=
None
,
fewshot_indices
=
None
)
->
"
ContextSampler
"
:
self
,
docs
:
list
[
dict
],
task
:
Task
,
rnd
=
None
,
fewshot_indices
=
None
)
->
ContextSampler
:
"""Initialize the sampler with the given documents and task."""
if
rnd
is
None
:
raise
ValueError
(
...
...
@@ -120,49 +121,49 @@ class FewshotConfig:
@
dataclass
class
TaskConfig
(
dict
):
# task naming/registry
task
:
Optional
[
str
]
=
None
task_alias
:
Optional
[
str
]
=
None
tag
:
Optional
[
Union
[
str
,
list
]]
=
None
task
:
str
|
None
=
None
task_alias
:
str
|
None
=
None
tag
:
str
|
list
|
None
=
None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset
:
Optional
[
Callable
]
=
None
dataset_path
:
Optional
[
str
]
=
None
dataset_name
:
Optional
[
str
]
=
None
dataset_kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
training_split
:
Optional
[
str
]
=
None
validation_split
:
Optional
[
str
]
=
None
test_split
:
Optional
[
str
]
=
None
fewshot_split
:
Optional
[
str
]
=
(
custom_dataset
:
Callable
|
None
=
None
dataset_path
:
str
|
None
=
None
dataset_name
:
str
|
None
=
None
dataset_kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
training_split
:
str
|
None
=
None
validation_split
:
str
|
None
=
None
test_split
:
str
|
None
=
None
fewshot_split
:
str
|
None
=
(
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs
:
Optional
[
Callable
]
=
None
doc_to_text
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_target
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_image
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_audio
:
Union
[
Callable
,
str
,
None
]
=
None
process_docs
:
Callable
|
None
=
None
doc_to_text
:
Callable
|
str
|
None
=
None
doc_to_target
:
Callable
|
str
|
None
=
None
doc_to_image
:
Callable
|
str
|
None
=
None
doc_to_audio
:
Callable
|
str
|
None
=
None
unsafe_code
:
bool
=
False
doc_to_choice
:
Optional
[
Union
[
Callable
,
str
,
dict
,
list
]]
=
None
process_results
:
Optional
[
Union
[
Callable
,
str
]]
=
None
use_prompt
:
Optional
[
str
]
=
None
doc_to_choice
:
Callable
|
str
|
dict
|
list
|
None
=
None
process_results
:
Callable
|
str
|
None
=
None
use_prompt
:
str
|
None
=
None
description
:
str
=
""
target_delimiter
:
str
=
" "
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_config
:
Optional
[
dict
]
=
None
fewshot_config
:
dict
|
None
=
None
# runtime configuration options
num_fewshot
:
Optional
[
int
]
=
0
generation_kwargs
:
Optional
[
dict
]
=
None
num_fewshot
:
int
|
None
=
0
generation_kwargs
:
dict
|
None
=
None
# scoring options
metric_list
:
Optional
[
list
]
=
None
metric_list
:
list
|
None
=
None
output_type
:
OutputType
=
"generate_until"
repeats
:
int
=
1
filter_list
:
Optional
[
list
[
dict
]
]
=
None
filter_list
:
list
[
dict
]
|
None
=
None
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
metadata
:
Optional
[
dict
]
=
field
(
doc_to_decontamination_query
:
str
|
None
=
None
gen_prefix
:
str
|
None
=
None
metadata
:
dict
|
None
=
field
(
default_factory
=
dict
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
...
...
@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
)
def
_get_metric
(
self
,
metric_list
:
Optional
[
list
[
dict
]]
=
None
)
->
list
[
"MetricConfig"
]:
def
_get_metric
(
self
,
metric_list
:
list
[
dict
]
|
None
=
None
)
->
list
[
MetricConfig
]:
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
...
...
@@ -314,7 +313,7 @@ class TaskConfig(dict):
return
metrics
@
property
def
get_filters
(
self
)
->
list
[
"
FilterConfig
"
]:
def
get_filters
(
self
)
->
list
[
FilterConfig
]:
from
lm_eval.filters
import
build_filter_ensemble
if
not
self
.
filter_list
:
...
...
@@ -354,7 +353,7 @@ class TaskConfig(dict):
return
x
@classmethod
def from_yaml(cls, data: dict) -> TaskConfig:
    """Build a TaskConfig from a parsed YAML mapping.

    Args:
        data: A dictionary of keyword arguments matching TaskConfig fields,
            typically produced by loading a task's YAML definition.

    Returns:
        A TaskConfig constructed by unpacking ``data`` into the initializer.
    """
    return cls(**data)
...
...
lm_eval/config/template.py
View file @
69d14fb3
from
__future__
import
annotations
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
if
TYPE_CHECKING
:
...
...
@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template."""
template
:
str
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
description
:
str
context_prefix
:
str
prefix_delimiter
:
str
context_delimiter
:
str
answer_suffix
:
str
target_delimiter
:
str
choice_format
:
Optional
[
str
]
choice_delimiter
:
Optional
[
str
]
choice_format
:
str
|
None
choice_delimiter
:
str
|
None
fewshot_delimiter
:
str
metric_list
:
Optional
[
Union
[
list
[
str
]
,
list
[
"
MetricConfig
"
]]]
=
field
(
metric_list
:
list
[
str
]
|
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
...
...
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice.
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
=
"mcq"
context_prefix
:
str
=
"Question:"
prefix_delimiter
:
str
=
" "
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
"
\n
"
choice_format
:
Optional
[
str
]
=
"letters"
choice_delimiter
:
Optional
[
str
]
=
"
\n
"
choice_format
:
str
|
None
=
"letters"
choice_delimiter
:
str
|
None
=
"
\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
default_factory
=
lambda
:
[
"acc"
])
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
])
@
dataclass
...
...
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>`
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
:
str
=
"cloze"
description
:
str
=
""
context_prefix
:
str
=
"Question:"
...
...
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
" "
choice_format
:
Optional
[
str
]
=
None
choice_delimiter
:
Optional
[
str
]
=
None
choice_format
:
str
|
None
=
None
choice_delimiter
:
str
|
None
=
None
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
lm_eval/config/utils.py
View file @
69d14fb3
from
__future__
import
annotations
from
inspect
import
getsource
from
typing
import
Any
,
Callable
,
Union
from
typing
import
Any
,
Callable
def
serialize_callable
(
value
:
Union
[
Callable
[...,
Any
]
,
str
]
,
keep_callable
=
False
)
->
Union
[
Callable
[...,
Any
]
,
str
]
:
value
:
Callable
[...,
Any
]
|
str
,
keep_callable
=
False
)
->
Callable
[...,
Any
]
|
str
:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
...
...
@@ -20,9 +22,7 @@ def serialize_callable(
return
str
(
value
)
def
maybe_serialize
(
val
:
Union
[
Callable
,
Any
],
keep_callable
=
False
)
->
Union
[
Callable
,
Any
]:
def
maybe_serialize
(
val
:
Callable
|
Any
,
keep_callable
=
False
)
->
Callable
|
Any
:
"""Conditionally serializes a value if it is callable."""
return
(
...
...
lm_eval/filters/extraction.py
View file @
69d14fb3
import
re
import
sys
import
unicodedata
from
collections.abc
import
Iterable
from
lm_eval.api.filter
import
Filter
from
lm_eval.api.registry
import
register_filter
...
...
@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self
.
group_select
=
group_select
self
.
fallback
=
fallback
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
...
...
@@ -59,59 +62,13 @@ class RegexFilter(Filter):
return
filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
    """Extract part-of-speech tags from responses containing (token, tag) pairs.

    Each response is expected to contain text such as
    ``[('word', 'NOUN'), ('runs', 'VERB')]`` (or to already be a list of
    ``(token, tag)`` tuples). The filter returns, per response, the list of
    tags only; when no pairs can be found, ``self.fallback``
    (default ``["invalid"]``) is returned for that response instead.
    """

    def __init__(
        self,
        regex_pattern: str = r"\['(.*?)'\]",
        group_select=0,
        fallback=None,
        **kwargs,
    ) -> None:
        """
        pass a string `regex` to run `re.compile(r"regex")` on.
        `fallback` defines the output returned if no matches for the regex are located.
        """
        super().__init__(**kwargs)
        # Avoid a mutable default argument; resolve None to the default here.
        if fallback is None:
            fallback = ["invalid"]
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def apply(self, resps, docs):
        """Map each model response to its extracted list of POS tags.

        `resps` is an iterable of response lists (one list per document);
        the same nesting is preserved in the returned value.
        """

        def extract_tagged_tokens(text):
            # re.findall with two capture groups already yields a list of
            # (token, pos) tuples; no further reshaping is needed.
            return re.findall(r"\('([^']*)', '([^']*)'\)", text)

        def extract_pos_tags(result):
            # Accept either a raw string (parse it) or a pre-parsed pair list.
            if isinstance(result, str):
                result = extract_tagged_tokens(result)
            pos_tags = [pos for _, pos in result]
            # NOTE: the fallback list object is shared across responses;
            # callers should treat it as read-only.
            return pos_tags if pos_tags else self.fallback

        # Materialize a list (rather than a one-shot `map` object) so the
        # result can be iterated more than once, matching sibling filters.
        return [[extract_pos_tags(resp) for resp in inst] for inst in resps]
@
register_filter
(
"remove_whitespace"
)
class
WhitespaceFilter
(
Filter
):
"""Filters out leading whitespace from responses."""
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
def
filter_set
(
inst
):
filtered_resp
=
[]
for
resp
in
inst
:
...
...
@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self
.
ignore_punctuation
=
ignore_punctuation
self
.
regexes_to_ignore
=
regexes_to_ignore
def
apply
(
self
,
resps
:
list
[
list
[
str
]],
docs
:
list
[
dict
])
->
list
[
list
[
str
]]:
def
apply
(
self
,
resps
:
Iterable
[
list
[
str
]],
docs
:
Iterable
[
dict
]
)
->
Iterable
[
list
[
str
]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment