Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
c3dabb32
Commit
c3dabb32
authored
May 16, 2023
by
lintangsutawika
Browse files
updated filter process
parent
4a3c1f19
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
51 additions
and
26 deletions
+51
-26
lm_eval/api/task.py
lm_eval/api/task.py
+26
-6
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+7
-7
lm_eval/filters/extraction.py
lm_eval/filters/extraction.py
+4
-5
lm_eval/tasks/gsm8k/base.yaml
lm_eval/tasks/gsm8k/base.yaml
+13
-3
main.py
main.py
+1
-5
No files found.
lm_eval/api/task.py
View file @
c3dabb32
...
@@ -56,7 +56,7 @@ class TaskConfig(dict):
...
@@ -56,7 +56,7 @@ class TaskConfig(dict):
gold_alias
:
str
=
None
gold_alias
:
str
=
None
output_type
:
str
=
"greedy_until"
output_type
:
str
=
"greedy_until"
delimiter
:
str
=
"
\n\n
"
delimiter
:
str
=
"
\n\n
"
filter
s
:
Union
[
str
,
list
]
=
None
filter
_list
:
Union
[
str
,
list
]
=
None
normalization
:
str
=
None
# TODO: add length-normalization of various types, mutual info
normalization
:
str
=
None
# TODO: add length-normalization of various types, mutual info
should_decontaminate
:
bool
=
False
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
str
=
None
doc_to_decontamination_query
:
str
=
None
...
@@ -428,7 +428,11 @@ class ConfigurableTask(Task):
...
@@ -428,7 +428,11 @@ class ConfigurableTask(Task):
CONFIG
=
None
CONFIG
=
None
def
__init__
(
def
__init__
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
config
:
dict
=
None
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
config
:
dict
=
None
):
):
# Get pre-configured attributes
# Get pre-configured attributes
self
.
_config
=
self
.
CONFIG
self
.
_config
=
self
.
CONFIG
...
@@ -489,11 +493,27 @@ class ConfigurableTask(Task):
...
@@ -489,11 +493,27 @@ class ConfigurableTask(Task):
self
.
_filters
=
[]
self
.
_filters
=
[]
for
name
,
components
in
self
.
_config
.
get
(
"filters"
,
[[
"none"
,
[
"take_first"
]]]):
filter_pipeline
=
build_filter_ensemble
(
name
,
components
)
if
self
.
_config
.
filter_list
!=
None
:
for
filter_config
in
self
.
_config
.
filter_list
:
for
filter_pipeline
in
filter_config
:
filter_name
=
filter_config
[
"name"
]
filter_functions
=
filter_config
[
"filter"
]
components
=
[]
for
function
in
filter_functions
:
kwargs
=
{
key
:
function
[
key
]
for
key
in
function
if
key
!=
"function"
}
components
.
append
([
function
[
'function'
],
kwargs
])
filter_pipeline
=
build_filter_ensemble
(
filter_name
,
components
)
self
.
_filters
.
append
(
filter_pipeline
)
self
.
_filters
.
append
(
filter_pipeline
)
self
.
sampler
=
samplers
.
Sampler
(
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
())
# TODO: pass the correct docs in here
if
self
.
fewshot_docs
()
!=
None
:
self
.
sampler
=
samplers
.
Sampler
(
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
random
.
Random
())
# TODO: pass the correct docs in here
def
has_training_docs
(
self
):
def
has_training_docs
(
self
):
if
self
.
_config
.
training_split
is
not
None
:
if
self
.
_config
.
training_split
is
not
None
:
...
@@ -526,7 +546,7 @@ class ConfigurableTask(Task):
...
@@ -526,7 +546,7 @@ class ConfigurableTask(Task):
return
self
.
dataset
[
self
.
_config
.
test_split
]
return
self
.
dataset
[
self
.
_config
.
test_split
]
def
fewshot_docs
(
self
):
def
fewshot_docs
(
self
):
if
(
self
.
num_fewshot
>
0
)
and
(
self
.
_config
.
fewshot_split
==
None
):
if
(
self
.
_config
.
num_fewshot
>
0
)
and
(
self
.
_config
.
fewshot_split
==
None
):
eval_logger
.
warning
(
eval_logger
.
warning
(
"num_fewshot > 0 but fewshot_split is None"
,
"num_fewshot > 0 but fewshot_split is None"
,
"using preconfigured rule."
"using preconfigured rule."
...
...
lm_eval/filters/__init__.py
View file @
c3dabb32
...
@@ -17,16 +17,16 @@ def get_filter(filter_name):
...
@@ -17,16 +17,16 @@ def get_filter(filter_name):
return
FILTER_REGISTRY
[
filter_name
]
return
FILTER_REGISTRY
[
filter_name
]
def
build_filter_ensemble
(
name
,
components
):
def
build_filter_ensemble
(
filter_
name
,
components
):
"""
"""
Create a filtering pipeline.
Create a filtering pipeline.
"""
"""
filters
=
[]
filters
=
[]
for
step
in
components
:
for
(
function
,
kwargs
)
in
components
:
# create a filter given its name in the registry
# create a filter given its name in the registry
f
=
get_filter
(
step
)(
)
# TODO: pass kwargs to filters properly
f
=
get_filter
(
function
)(
**
kwargs
)
# TODO: pass kwargs to filters properly
# add the filter as a pipeline step
# add the filter as a pipeline step
filters
.
append
(
f
)
filters
.
append
(
f
)
return
FilterEnsemble
(
name
=
name
,
filters
=
filters
)
return
FilterEnsemble
(
name
=
filter_
name
,
filters
=
filters
)
lm_eval/filters/extraction.py
View file @
c3dabb32
...
@@ -9,14 +9,13 @@ class RegexFilter(Filter):
...
@@ -9,14 +9,13 @@ class RegexFilter(Filter):
"""
"""
def
__init__
(
self
,
regex
=
r
"#### (\-?[0-9\.\,]+)"
,
fallback
=
"[invalid]"
):
def
__init__
(
self
,
regex
_pattern
=
r
"#### (\-?[0-9\.\,]+)"
,
fallback
=
"[invalid]"
):
"""
"""
pass a string `regex` to run `re.compile(r"regex")` on.
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
`fallback` defines the output returned if no matches for the regex are located.
"""
"""
self
.
regex_pattern
=
regex
self
.
regex_pattern
=
regex_pattern
self
.
regex
=
re
.
compile
(
regex
)
self
.
regex
=
re
.
compile
(
regex_pattern
)
self
.
fallback
=
fallback
self
.
fallback
=
fallback
def
apply
(
self
,
resps
):
def
apply
(
self
,
resps
):
...
@@ -30,7 +29,7 @@ class RegexFilter(Filter):
...
@@ -30,7 +29,7 @@ class RegexFilter(Filter):
match
=
self
.
regex
.
search
(
resp
)
match
=
self
.
regex
.
search
(
resp
)
if
match
:
if
match
:
match
=
match
.
group
(
1
).
strip
()
match
=
match
.
group
(
1
).
strip
()
match
_str
.
replace
(
","
,
""
)
match
.
replace
(
","
,
""
)
# TODO: should we assume any other filtering is performed?
# TODO: should we assume any other filtering is performed?
else
:
else
:
match
=
self
.
fallback
match
=
self
.
fallback
...
...
lm_eval/tasks/gsm8k/base.yaml
View file @
c3dabb32
...
@@ -40,6 +40,16 @@ metric_list:
...
@@ -40,6 +40,16 @@ metric_list:
ignore_case
:
true
ignore_case
:
true
ignore_punctuation
:
true
ignore_punctuation
:
true
delimiter
:
"
\n
"
delimiter
:
"
\n
"
# filters: [
filter_list
:
# ["regex", ["regex", "take_first"]]
-
name
:
"
just
regex"
# ]
filter
:
\ No newline at end of file
-
function
:
"
regex"
regex_pattern
:
"
.*"
-
function
:
"
regex"
regex_pattern
:
"
.*"
-
name
:
"
another
regex"
filter
:
-
function
:
"
regex"
regex_pattern
:
"
.*"
-
function
:
"
regex"
regex_pattern
:
"
.*"
\ No newline at end of file
main.py
View file @
c3dabb32
import
os
import
os
import
yaml
import
json
import
json
import
fnmatch
import
fnmatch
import
warnings
import
argparse
import
argparse
from
pprint
import
pformat
from
lm_eval
import
evaluator
,
utils
from
lm_eval
import
evaluator
,
utils
from
lm_eval.tasks
import
ALL_TASKS
from
lm_eval.tasks
import
ALL_TASKS
from
lm_eval.logger
import
eval_logger
from
lm_eval.logger
import
eval_logger
...
@@ -22,7 +18,7 @@ class MultiChoice:
...
@@ -22,7 +18,7 @@ class MultiChoice:
for
value
in
values
.
split
(
","
):
for
value
in
values
.
split
(
","
):
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
:
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
:
eval_logger
.
warning
(
"{} is not in task list."
.
format
(
value
))
eval_logger
.
warning
(
"{} is not in task list."
.
format
(
value
))
# eval_logger.info(f"{
ALL_TASKS
} is this")
# eval_logger.info(f"{
choices
} is this")
return
True
return
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment