Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
d27c0c08
Unverified
Commit
d27c0c08
authored
Feb 26, 2024
by
LSinev
Committed by
GitHub
Feb 26, 2024
Browse files
Apply code autoformatting with Ruff to tasks/*.py an *__init__.py (#1469)
parent
f78e2da4
Changes
48
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
252 additions
and
150 deletions
+252
-150
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+3
-4
lm_eval/models/__init__.py
lm_eval/models/__init__.py
+14
-10
lm_eval/prompts/__init__.py
lm_eval/prompts/__init__.py
+3
-2
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+75
-57
lm_eval/tasks/bbh/_generate_configs.py
lm_eval/tasks/bbh/_generate_configs.py
+3
-3
lm_eval/tasks/bbh/cot_zeroshot/utils.py
lm_eval/tasks/bbh/cot_zeroshot/utils.py
+58
-25
lm_eval/tasks/bbh/zeroshot/utils.py
lm_eval/tasks/bbh/zeroshot/utils.py
+58
-25
lm_eval/tasks/belebele/_generate_configs.py
lm_eval/tasks/belebele/_generate_configs.py
+6
-4
lm_eval/tasks/bigbench/generate_tasks.py
lm_eval/tasks/bigbench/generate_tasks.py
+2
-0
lm_eval/tasks/bigbench/push_bigbench_dataset.py
lm_eval/tasks/bigbench/push_bigbench_dataset.py
+2
-3
lm_eval/tasks/blimp/generate_configs.py
lm_eval/tasks/blimp/generate_configs.py
+1
-0
lm_eval/tasks/ceval/_generate_configs.py
lm_eval/tasks/ceval/_generate_configs.py
+3
-2
lm_eval/tasks/cmmlu/_generate_configs.py
lm_eval/tasks/cmmlu/_generate_configs.py
+3
-2
lm_eval/tasks/code_x_glue/code-text/bleu.py
lm_eval/tasks/code_x_glue/code-text/bleu.py
+3
-3
lm_eval/tasks/csatqa/_generate_configs.py
lm_eval/tasks/csatqa/_generate_configs.py
+3
-2
lm_eval/tasks/drop/utils.py
lm_eval/tasks/drop/utils.py
+1
-0
lm_eval/tasks/gpqa/n_shot/_generate_configs.py
lm_eval/tasks/gpqa/n_shot/_generate_configs.py
+1
-1
lm_eval/tasks/gpqa/n_shot/utils.py
lm_eval/tasks/gpqa/n_shot/utils.py
+7
-3
lm_eval/tasks/gpqa/zeroshot/_generate_configs.py
lm_eval/tasks/gpqa/zeroshot/_generate_configs.py
+1
-1
lm_eval/tasks/gpqa/zeroshot/utils.py
lm_eval/tasks/gpqa/zeroshot/utils.py
+5
-3
No files found.
lm_eval/filters/__init__.py
View file @
d27c0c08
from
typing
import
List
,
Union
from
functools
import
partial
from
typing
import
List
,
Union
from
lm_eval.api.filter
import
FilterEnsemble
from
.
import
selection
from
.
import
extraction
from
.
import
transformation
from
.
import
extraction
,
selection
,
transformation
FILTER_REGISTRY
=
{
...
...
lm_eval/models/__init__.py
View file @
d27c0c08
from
.
import
huggingface
from
.
import
openai_completions
from
.
import
textsynth
from
.
import
dummy
from
.
import
anthropic_llms
from
.
import
gguf
from
.
import
vllm_causallms
from
.
import
mamba_lm
from
.
import
optimum_lm
from
.
import
neuron_optimum
from
.
import
(
anthropic_llms
,
dummy
,
gguf
,
huggingface
,
mamba_lm
,
neuron_optimum
,
openai_completions
,
optimum_lm
,
textsynth
,
vllm_causallms
,
)
# TODO: implement __all__
...
...
lm_eval/prompts/__init__.py
View file @
d27c0c08
import
os
import
ast
import
os
from
typing
import
Dict
from
lm_eval
import
utils
from
lm_eval.utils
import
eval_logger
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
...
...
lm_eval/tasks/__init__.py
View file @
d27c0c08
import
os
import
abc
import
collections
import
logging
import
os
from
functools
import
partial
from
typing
import
List
,
Union
,
Dict
from
typing
import
Dict
,
List
,
Union
from
lm_eval
import
utils
from
lm_eval.api.task
import
Task
,
ConfigurableTask
import
logging
from
lm_eval.api.task
import
ConfigurableTask
,
Task
class
TaskManager
:
...
...
@@ -16,20 +14,14 @@ class TaskManager:
and an optional directory if provided.
"""
def
__init__
(
self
,
verbosity
=
"INFO"
,
include_path
=
None
)
->
None
:
def
__init__
(
self
,
verbosity
=
"INFO"
,
include_path
=
None
)
->
None
:
self
.
verbosity
=
verbosity
self
.
include_path
=
include_path
self
.
logger
=
utils
.
eval_logger
self
.
logger
.
setLevel
(
getattr
(
logging
,
f
"
{
verbosity
}
"
))
self
.
_task_index
=
self
.
initialize_tasks
(
include_path
=
include_path
)
self
.
_task_index
=
self
.
initialize_tasks
(
include_path
=
include_path
)
self
.
_all_tasks
=
sorted
(
list
(
self
.
_task_index
.
keys
()))
self
.
task_group_map
=
collections
.
defaultdict
(
list
)
...
...
@@ -65,27 +57,29 @@ class TaskManager:
return
self
.
_task_index
def
match_tasks
(
self
,
task_list
):
return
utils
.
pattern_match
(
task_list
,
self
.
all_tasks
)
return
utils
.
pattern_match
(
task_list
,
self
.
all_tasks
)
def
_name_is_registered
(
self
,
name
):
if
name
in
self
.
all_tasks
:
return
True
return
False
def
_name_is_task
(
self
,
name
):
def
_name_is_task
(
self
,
name
)
->
bool
:
if
self
.
_name_is_registered
(
name
)
and
(
"task"
in
self
.
task_index
[
name
][
"type"
]):
return
True
return
False
def
_name_is_group
(
self
,
name
):
if
self
.
_name_is_registered
(
name
)
and
(
self
.
task_index
[
name
][
"type"
]
==
"group"
):
if
self
.
_name_is_registered
(
name
)
and
(
self
.
task_index
[
name
][
"type"
]
==
"group"
):
return
True
return
False
def
_name_is_python_task
(
self
,
name
):
if
self
.
_name_is_registered
(
name
)
and
(
self
.
task_index
[
name
][
"type"
]
==
"python_task"
):
if
self
.
_name_is_registered
(
name
)
and
(
self
.
task_index
[
name
][
"type"
]
==
"python_task"
):
return
True
return
False
...
...
@@ -117,7 +111,7 @@ class TaskManager:
return
utils
.
load_yaml_config
(
yaml_path
,
mode
=
"full"
)
def
_get_tasklist
(
self
,
name
):
assert
self
.
_name_is_task
(
name
)
==
False
assert
self
.
_name_is_task
(
name
)
is
False
return
self
.
task_index
[
name
][
"task"
]
def
_process_alias
(
self
,
config
,
group
=
None
):
...
...
@@ -130,12 +124,12 @@ class TaskManager:
return
config
def
_load_individual_task_or_group
(
self
,
name_or_config
:
Union
[
str
,
dict
]
=
None
,
parent_name
:
str
=
None
,
update_config
:
dict
=
None
,
yaml_path
:
str
=
None
,
)
->
ConfigurableTask
:
self
,
name_or_config
:
Union
[
str
,
dict
]
=
None
,
parent_name
:
str
=
None
,
update_config
:
dict
=
None
,
yaml_path
:
str
=
None
,
)
->
ConfigurableTask
:
def
load_task
(
config
,
task
,
group
=
None
,
yaml_path
=
None
):
if
"include"
in
config
:
assert
yaml_path
is
not
None
...
...
@@ -174,7 +168,9 @@ class TaskManager:
group_config
=
self
.
_get_config
(
name_or_config
)
if
set
(
group_config
.
keys
())
>
set
([
"task"
,
"group"
]):
update_config
=
{
k
:
v
for
k
,
v
in
group_config
.
items
()
if
k
not
in
[
"task"
,
"group"
]
k
:
v
for
k
,
v
in
group_config
.
items
()
if
k
not
in
[
"task"
,
"group"
]
}
yaml_path
=
self
.
_get_yaml_path
(
group_name
)
...
...
@@ -183,9 +179,8 @@ class TaskManager:
update_config
.
pop
(
"group_alias"
)
if
isinstance
(
name_or_config
,
dict
):
if
update_config
is
not
None
:
name_or_config
=
{
name_or_config
=
{
**
name_or_config
,
**
update_config
,
}
...
...
@@ -196,7 +191,9 @@ class TaskManager:
# if self._name_is_task(name) is False:
if
self
.
_name_is_group
(
name
):
group_name
=
name
update_config
=
{
k
:
v
for
k
,
v
in
name_or_config
.
items
()
if
k
!=
"task"
}
update_config
=
{
k
:
v
for
k
,
v
in
name_or_config
.
items
()
if
k
!=
"task"
}
subtask_list
=
self
.
_get_tasklist
(
name
)
if
subtask_list
==
-
1
:
subtask_list
=
self
.
_get_config
(
name
)[
"task"
]
...
...
@@ -207,36 +204,53 @@ class TaskManager:
# Check if this is a duplicate.
if
parent_name
is
not
None
:
name_or_config
[
"group"
]
=
parent_name
num_duplicate
=
len
(
list
(
filter
(
lambda
x
:
x
.
startswith
(
name
),
self
.
task_group_map
[
parent_name
])))
num_duplicate
=
len
(
list
(
filter
(
lambda
x
:
x
.
startswith
(
name
),
self
.
task_group_map
[
parent_name
],
)
)
)
if
num_duplicate
>
0
:
name
=
f
"
{
name
}
-
{
num_duplicate
}
"
self
.
task_group_map
[
parent_name
].
append
(
name
)
task_config
=
{
**
base_task_config
,
**
name_or_config
,
}
task_config
=
{
**
base_task_config
,
**
name_or_config
,
}
else
:
task_config
=
name_or_config
return
load_task
(
task_config
,
task
=
name
,
group
=
parent_name
,
yaml_path
=
yaml_path
)
return
load_task
(
task_config
,
task
=
name
,
group
=
parent_name
,
yaml_path
=
yaml_path
)
else
:
group_name
=
name_or_config
[
"group"
]
subtask_list
=
name_or_config
[
"task"
]
# update_config = {k:v for k,v in name_or_config.items() if k != "task"}
if
set
(
name_or_config
.
keys
())
>
set
([
"task"
,
"group"
]):
update_config
=
{
k
:
v
for
k
,
v
in
name_or_config
.
items
()
if
k
not
in
[
"task"
,
"group"
]
k
:
v
for
k
,
v
in
name_or_config
.
items
()
if
k
not
in
[
"task"
,
"group"
]
}
all_subtasks
=
{}
if
(
parent_name
is
not
None
)
:
if
parent_name
is
not
None
:
all_subtasks
=
{
group_name
:
(
parent_name
,
None
)}
fn
=
partial
(
self
.
_load_individual_task_or_group
,
parent_name
=
group_name
,
update_config
=
update_config
,
yaml_path
=
yaml_path
)
all_subtasks
=
{
**
all_subtasks
,
**
dict
(
collections
.
ChainMap
(
*
map
(
fn
,
subtask_list
)))}
fn
=
partial
(
self
.
_load_individual_task_or_group
,
parent_name
=
group_name
,
update_config
=
update_config
,
yaml_path
=
yaml_path
,
)
all_subtasks
=
{
**
all_subtasks
,
**
dict
(
collections
.
ChainMap
(
*
map
(
fn
,
subtask_list
))),
}
return
all_subtasks
def
load_task_or_group
(
self
,
task_list
:
Union
[
str
,
list
]
=
None
)
->
dict
:
"""Loads a dictionary of task objects from a list
...
...
@@ -250,12 +264,7 @@ class TaskManager:
task_list
=
[
task_list
]
all_loaded_tasks
=
dict
(
collections
.
ChainMap
(
*
map
(
self
.
_load_individual_task_or_group
,
task_list
)
)
collections
.
ChainMap
(
*
map
(
self
.
_load_individual_task_or_group
,
task_list
))
)
return
all_loaded_tasks
...
...
@@ -299,11 +308,11 @@ class TaskManager:
# This is a group config
tasks_and_groups
[
config
[
"group"
]]
=
{
"type"
:
"group"
,
"task"
:
-
1
,
# This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"task"
:
-
1
,
# This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path"
:
yaml_path
,
}
...
...
@@ -322,7 +331,7 @@ class TaskManager:
tasks_and_groups
[
task
]
=
{
"type"
:
"task"
,
"yaml_path"
:
yaml_path
,
}
}
if
"group"
in
config
:
groups
=
config
[
"group"
]
...
...
@@ -343,6 +352,7 @@ class TaskManager:
return
tasks_and_groups
def
include_path
(
task_dir
):
logger
=
utils
.
eval_logger
logger
.
setLevel
(
getattr
(
logging
,
"INFO"
))
...
...
@@ -352,6 +362,7 @@ def include_path(task_dir):
)
return
0
def
initialize_tasks
(
verbosity
=
"INFO"
):
logger
=
utils
.
eval_logger
logger
.
setLevel
(
getattr
(
logging
,
f
"
{
verbosity
}
"
))
...
...
@@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"):
)
return
0
def
get_task_name_from_config
(
task_config
:
Dict
[
str
,
str
])
->
str
:
if
"task"
in
task_config
:
return
task_config
[
"task"
]
...
...
@@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
else
:
return
"{dataset_path}"
.
format
(
**
task_config
)
def
get_task_name_from_object
(
task_object
):
if
hasattr
(
task_object
,
"config"
):
return
task_object
.
_config
[
"task"
]
...
...
@@ -382,7 +395,10 @@ def get_task_name_from_object(task_object):
else
type
(
task_object
).
__name__
)
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
Dict
,
Task
]],
task_manager
:
TaskManager
=
None
):
def
get_task_dict
(
task_name_list
:
List
[
Union
[
str
,
Dict
,
Task
]],
task_manager
:
TaskManager
=
None
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
...
...
@@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta
if
task_manager
is
None
:
task_manager
=
TaskManager
()
task_name_from_string_dict
=
task_manager
.
load_task_or_group
(
string_task_name_list
)
task_name_from_string_dict
=
task_manager
.
load_task_or_group
(
string_task_name_list
)
for
task_element
in
others_task_name_list
:
if
isinstance
(
task_element
,
dict
):
...
...
lm_eval/tasks/bbh/_generate_configs.py
View file @
d27c0c08
"""
Take in a YAML, and output all other splits with this YAML
"""
import
argparse
import
os
import
re
import
yaml
import
requests
import
argparse
import
datasets
import
requests
import
yaml
from
tqdm
import
tqdm
from
lm_eval
import
utils
...
...
lm_eval/tasks/bbh/cot_zeroshot/utils.py
View file @
d27c0c08
import
collections
import
re
import
sys
import
unicodedata
from
lm_eval.filters.extraction
import
Regex
Filter
,
Filter
from
lm_eval.filters.extraction
import
Filter
,
Regex
Filter
class
ExtendedRegexFilter
(
RegexFilter
):
punct_tbl
=
dict
.
fromkeys
(
i
for
i
in
range
(
sys
.
maxunicode
)
if
unicodedata
.
category
(
chr
(
i
)).
startswith
(
'P'
))
punct_tbl
=
dict
.
fromkeys
(
i
for
i
in
range
(
sys
.
maxunicode
)
if
unicodedata
.
category
(
chr
(
i
)).
startswith
(
"P"
)
)
def
__init__
(
self
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
self
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
)
->
None
:
super
().
__init__
(
regex_pattern
,
group_select
,
fallback
)
self
.
ignore_case
=
ignore_case
...
...
@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter):
class
MapRegexFilter
(
ExtendedRegexFilter
):
def
__init__
(
self
,
regex_pattern_to_value
:
dict
=
{},
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
self
,
regex_pattern_to_value
:
dict
=
{},
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
)
->
None
:
"""
regex_pattern_to_value: Match the regex pattern and change the result into the value
...
...
@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter):
ignore_punctuation: Remove the punctuation before matching with the given regex
regexes_to_ignore: Remove these regexes before matching with the given regex
"""
super
().
__init__
(
'|'
.
join
(
list
(
regex_pattern_to_value
.
keys
())),
group_select
,
fallback
,
ignore_case
,
ignore_punctuation
,
regexes_to_ignore
)
self
.
regex_to_value
=
{
re
.
compile
(
r
):
v
for
r
,
v
in
regex_pattern_to_value
.
items
()}
super
().
__init__
(
"|"
.
join
(
list
(
regex_pattern_to_value
.
keys
())),
group_select
,
fallback
,
ignore_case
,
ignore_punctuation
,
regexes_to_ignore
,
)
self
.
regex_to_value
=
{
re
.
compile
(
r
):
v
for
r
,
v
in
regex_pattern_to_value
.
items
()
}
def
apply
(
self
,
resps
,
docs
):
filtered_resps
=
[]
...
...
@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter):
for
r
in
resps
:
filtered
=
[]
for
resp
in
r
:
whole_match_considering_group_select
=
self
.
find_match
(
self
.
regex
,
self
.
filter_ignores
(
resp
))
whole_match_considering_group_select
=
self
.
find_match
(
self
.
regex
,
self
.
filter_ignores
(
resp
)
)
if
whole_match_considering_group_select
:
for
regex
,
mapped_value
in
self
.
regex_to_value
.
items
():
match
=
self
.
find_match
(
regex
,
self
.
filter_ignores
(
whole_match_considering_group_select
))
match
=
self
.
find_match
(
regex
,
self
.
filter_ignores
(
whole_match_considering_group_select
),
)
if
match
:
match
=
mapped_value
break
...
...
@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter):
filtered_resps
=
[]
import
regex
from
word2number
import
w2n
# https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words
english_number_regex
=
regex
.
compile
(
"((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S
\r\n
]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
)
"((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S
\r\n
]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
)
for
r
in
resps
:
filtered
=
[]
...
...
@@ -118,21 +144,22 @@ class WordSortFilter(Filter):
filtered_resps
=
[]
for
r
,
doc
in
zip
(
resps
,
docs
):
words
=
doc
[
'
input
'
].
split
(
"List:"
)[
1
].
strip
().
split
()
regex
=
re
.
compile
(
'|'
.
join
([
f
"
\\
b
{
w
}
\\
b"
for
w
in
words
]))
words
=
doc
[
"
input
"
].
split
(
"List:"
)[
1
].
strip
().
split
()
regex
=
re
.
compile
(
"|"
.
join
([
f
"
\\
b
{
w
}
\\
b"
for
w
in
words
]))
filtered
=
[]
for
resp
in
r
:
match
=
regex
.
findall
(
resp
)
match
.
reverse
()
ordered_words
=
reversed
(
collections
.
OrderedDict
(
zip
(
match
,
[
None
]
*
len
(
match
))))
filtered
.
append
(
' '
.
join
(
ordered_words
))
ordered_words
=
reversed
(
collections
.
OrderedDict
(
zip
(
match
,
[
None
]
*
len
(
match
)))
)
filtered
.
append
(
" "
.
join
(
ordered_words
))
filtered_resps
.
append
(
filtered
)
return
filtered_resps
class
MultiChoiceRegexFilter
(
ExtendedRegexFilter
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
...
...
@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
for
r
,
doc
in
zip
(
resps
,
docs
):
fallback_regexes
=
[]
choice_to_alpha
=
{}
next_alpha
=
'A'
next_alpha
=
"A"
without_paren_fallback_regexes
=
[]
without_paren_to_target
=
{}
multiple_choices_regex
=
re
.
compile
(
r
"\([A-Z]\)([^\n^(]*)"
)
match
=
multiple_choices_regex
.
findall
(
doc
[
'
input
'
])
match
=
multiple_choices_regex
.
findall
(
doc
[
"
input
"
])
for
m
in
match
:
m
=
self
.
filter_ignores
(
m
.
strip
())
fallback_regexes
.
append
(
f
"
{
re
.
escape
(
m
)
}
"
)
...
...
@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
without_paren_to_target
[
next_alpha
]
=
f
"(
{
next_alpha
}
)"
next_alpha
=
chr
(
ord
(
next_alpha
)
+
1
)
fallback_regex
=
re
.
compile
(
'|'
.
join
(
fallback_regexes
))
without_paren_fallback_regex
=
'|'
.
join
(
without_paren_fallback_regexes
)
without_paren_fallback_regex
=
re
.
compile
(
f
":[\s]*(
{
without_paren_fallback_regex
}
)"
)
fallback_regex
=
re
.
compile
(
"|"
.
join
(
fallback_regexes
))
without_paren_fallback_regex
=
"|"
.
join
(
without_paren_fallback_regexes
)
without_paren_fallback_regex
=
re
.
compile
(
f
":[\s]*(
{
without_paren_fallback_regex
}
)"
)
filtered
=
[]
for
resp
in
r
:
match
=
self
.
find_match
(
self
.
regex
,
resp
)
if
not
match
:
match
=
self
.
find_match
(
fallback_regex
,
self
.
filter_ignores
(
resp
),
choice_to_alpha
)
match
=
self
.
find_match
(
fallback_regex
,
self
.
filter_ignores
(
resp
),
choice_to_alpha
)
if
not
match
:
match
=
self
.
find_match
(
without_paren_fallback_regex
,
resp
,
without_paren_to_target
)
match
=
self
.
find_match
(
without_paren_fallback_regex
,
resp
,
without_paren_to_target
)
if
not
match
:
match
=
self
.
fallback
filtered
.
append
(
match
)
...
...
lm_eval/tasks/bbh/zeroshot/utils.py
View file @
d27c0c08
import
collections
import
re
import
sys
import
unicodedata
from
lm_eval.filters.extraction
import
Regex
Filter
,
Filter
from
lm_eval.filters.extraction
import
Filter
,
Regex
Filter
class
ExtendedRegexFilter
(
RegexFilter
):
punct_tbl
=
dict
.
fromkeys
(
i
for
i
in
range
(
sys
.
maxunicode
)
if
unicodedata
.
category
(
chr
(
i
)).
startswith
(
'P'
))
punct_tbl
=
dict
.
fromkeys
(
i
for
i
in
range
(
sys
.
maxunicode
)
if
unicodedata
.
category
(
chr
(
i
)).
startswith
(
"P"
)
)
def
__init__
(
self
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
self
,
regex_pattern
:
str
=
r
"#### (\-?[0-9\.\,]+)"
,
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
)
->
None
:
super
().
__init__
(
regex_pattern
,
group_select
,
fallback
)
self
.
ignore_case
=
ignore_case
...
...
@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter):
class
MapRegexFilter
(
ExtendedRegexFilter
):
def
__init__
(
self
,
regex_pattern_to_value
:
dict
=
{},
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
self
,
regex_pattern_to_value
:
dict
=
{},
group_select
=
0
,
fallback
:
str
=
"[invalid]"
,
ignore_case
=
False
,
ignore_punctuation
=
False
,
regexes_to_ignore
=
None
,
)
->
None
:
"""
regex_pattern_to_value: Match the regex pattern and change the result into the value
...
...
@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter):
ignore_punctuation: Remove the punctuation before matching with the given regex
regexes_to_ignore: Remove these regexes before matching with the given regex
"""
super
().
__init__
(
'|'
.
join
(
list
(
regex_pattern_to_value
.
keys
())),
group_select
,
fallback
,
ignore_case
,
ignore_punctuation
,
regexes_to_ignore
)
self
.
regex_to_value
=
{
re
.
compile
(
r
):
v
for
r
,
v
in
regex_pattern_to_value
.
items
()}
super
().
__init__
(
"|"
.
join
(
list
(
regex_pattern_to_value
.
keys
())),
group_select
,
fallback
,
ignore_case
,
ignore_punctuation
,
regexes_to_ignore
,
)
self
.
regex_to_value
=
{
re
.
compile
(
r
):
v
for
r
,
v
in
regex_pattern_to_value
.
items
()
}
def
apply
(
self
,
resps
,
docs
):
filtered_resps
=
[]
...
...
@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter):
for
r
in
resps
:
filtered
=
[]
for
resp
in
r
:
whole_match_considering_group_select
=
self
.
find_match
(
self
.
regex
,
self
.
filter_ignores
(
resp
))
whole_match_considering_group_select
=
self
.
find_match
(
self
.
regex
,
self
.
filter_ignores
(
resp
)
)
if
whole_match_considering_group_select
:
for
regex
,
mapped_value
in
self
.
regex_to_value
.
items
():
match
=
self
.
find_match
(
regex
,
self
.
filter_ignores
(
whole_match_considering_group_select
))
match
=
self
.
find_match
(
regex
,
self
.
filter_ignores
(
whole_match_considering_group_select
),
)
if
match
:
match
=
mapped_value
break
...
...
@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter):
filtered_resps
=
[]
import
regex
from
word2number
import
w2n
# https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words
english_number_regex
=
regex
.
compile
(
"((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S
\r\n
]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
)
"((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S
\r\n
]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
)
for
r
in
resps
:
filtered
=
[]
...
...
@@ -118,21 +144,22 @@ class WordSortFilter(Filter):
filtered_resps
=
[]
for
r
,
doc
in
zip
(
resps
,
docs
):
words
=
doc
[
'
input
'
].
split
(
"List:"
)[
1
].
strip
().
split
()
regex
=
re
.
compile
(
'|'
.
join
([
f
"
\\
b
{
w
}
\\
b"
for
w
in
words
]))
words
=
doc
[
"
input
"
].
split
(
"List:"
)[
1
].
strip
().
split
()
regex
=
re
.
compile
(
"|"
.
join
([
f
"
\\
b
{
w
}
\\
b"
for
w
in
words
]))
filtered
=
[]
for
resp
in
r
:
match
=
regex
.
findall
(
resp
)
match
.
reverse
()
ordered_words
=
reversed
(
collections
.
OrderedDict
(
zip
(
match
,
[
None
]
*
len
(
match
))))
filtered
.
append
(
' '
.
join
(
ordered_words
))
ordered_words
=
reversed
(
collections
.
OrderedDict
(
zip
(
match
,
[
None
]
*
len
(
match
)))
)
filtered
.
append
(
" "
.
join
(
ordered_words
))
filtered_resps
.
append
(
filtered
)
return
filtered_resps
class
MultiChoiceRegexFilter
(
ExtendedRegexFilter
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
...
...
@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
for
r
,
doc
in
zip
(
resps
,
docs
):
fallback_regexes
=
[]
choice_to_alpha
=
{}
next_alpha
=
'A'
next_alpha
=
"A"
without_paren_fallback_regexes
=
[]
without_paren_to_target
=
{}
multiple_choices_regex
=
re
.
compile
(
r
"\([A-Z]\)([^\n^(]*)"
)
match
=
multiple_choices_regex
.
findall
(
doc
[
'
input
'
])
match
=
multiple_choices_regex
.
findall
(
doc
[
"
input
"
])
for
m
in
match
:
m
=
self
.
filter_ignores
(
m
.
strip
())
fallback_regexes
.
append
(
f
"
{
re
.
escape
(
m
)
}
"
)
...
...
@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
without_paren_to_target
[
next_alpha
]
=
f
"(
{
next_alpha
}
)"
next_alpha
=
chr
(
ord
(
next_alpha
)
+
1
)
fallback_regex
=
re
.
compile
(
'|'
.
join
(
fallback_regexes
))
without_paren_fallback_regex
=
'|'
.
join
(
without_paren_fallback_regexes
)
without_paren_fallback_regex
=
re
.
compile
(
f
":[\s]*(
{
without_paren_fallback_regex
}
)"
)
fallback_regex
=
re
.
compile
(
"|"
.
join
(
fallback_regexes
))
without_paren_fallback_regex
=
"|"
.
join
(
without_paren_fallback_regexes
)
without_paren_fallback_regex
=
re
.
compile
(
f
":[\s]*(
{
without_paren_fallback_regex
}
)"
)
filtered
=
[]
for
resp
in
r
:
match
=
self
.
find_match
(
self
.
regex
,
resp
)
if
not
match
:
match
=
self
.
find_match
(
fallback_regex
,
self
.
filter_ignores
(
resp
),
choice_to_alpha
)
match
=
self
.
find_match
(
fallback_regex
,
self
.
filter_ignores
(
resp
),
choice_to_alpha
)
if
not
match
:
match
=
self
.
find_match
(
without_paren_fallback_regex
,
resp
,
without_paren_to_target
)
match
=
self
.
find_match
(
without_paren_fallback_regex
,
resp
,
without_paren_to_target
)
if
not
match
:
match
=
self
.
fallback
filtered
.
append
(
match
)
...
...
lm_eval/tasks/belebele/_generate_configs.py
View file @
d27c0c08
"""
Take in a YAML, and output all other splits with this YAML
"""
import
os
import
yaml
import
argparse
import
request
s
import
o
s
import
requests
import
yaml
from
tqdm
import
tqdm
from
lm_eval.utils
import
logging
API_URL
=
"https://datasets-server.huggingface.co/splits?dataset=facebook/belebele"
...
...
@@ -39,6 +40,7 @@ if __name__ == "__main__":
def
query
():
response
=
requests
.
get
(
API_URL
)
return
response
.
json
()[
"splits"
]
print
(
query
())
languages
=
[
split
[
"split"
]
for
split
in
query
()]
...
...
@@ -49,7 +51,7 @@ if __name__ == "__main__":
if
args
.
task_prefix
!=
""
else
f
"belebele_
{
lang
}
"
,
"test_split"
:
lang
,
"fewshot_split"
:
lang
,
"fewshot_split"
:
lang
,
}
file_save_path
=
args
.
save_prefix_path
+
f
"_
{
lang
}
.yaml"
...
...
lm_eval/tasks/bigbench/generate_tasks.py
View file @
d27c0c08
import
os
import
yaml
all_subtasks
=
[
"abstract_narrative_understanding"
,
"anachronisms"
,
...
...
lm_eval/tasks/bigbench/push_bigbench_dataset.py
View file @
d27c0c08
...
...
@@ -8,10 +8,9 @@ Requires the installation of
`pip install "bigbench @ https://storage.googleapis.com/public_research_data/bigbench/bigbench-0.0.1.tar.gz"`
and is included so that the bigbench dependency can be avoided.
"""
from
tqdm
import
tqdm
import
datasets
import
bigbench.api.util
as
bb_utils
import
datasets
from
tqdm
import
tqdm
all_task_names
=
bb_utils
.
get_all_json_task_names
()
...
...
lm_eval/tasks/blimp/generate_configs.py
View file @
d27c0c08
import
yaml
all_subtasks
=
[
"adjunct_island"
,
"anaphor_gender_agreement"
,
...
...
lm_eval/tasks/ceval/_generate_configs.py
View file @
d27c0c08
"""
Take in a YAML, and output all other splits with this YAML
"""
import
os
import
yaml
import
argparse
import
os
import
yaml
from
tqdm
import
tqdm
from
lm_eval.logger
import
eval_logger
SUBJECTS
=
{
"computer_network"
:
"计算机网络"
,
"operating_system"
:
"操作系统"
,
...
...
lm_eval/tasks/cmmlu/_generate_configs.py
View file @
d27c0c08
"""
Take in a YAML, and output all other splits with this YAML
"""
import
os
import
yaml
import
argparse
import
os
import
yaml
from
tqdm
import
tqdm
from
lm_eval.logger
import
eval_logger
SUBJECTS
=
{
"agronomy"
:
"农学"
,
"anatomy"
:
"解剖学"
,
...
...
lm_eval/tasks/code_x_glue/code-text/bleu.py
View file @
d27c0c08
#!/usr/bin/python
import
math
import
re
import
sys
import
math
import
xml.sax.saxutils
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Pattern
,
Tuple
,
Union
from
typing
import
List
,
Pattern
,
Tuple
,
Union
,
Dict
,
Any
,
Optional
"""
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
...
...
@@ -60,7 +60,7 @@ def normalize(s):
# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
if
nonorm
:
return
s
.
split
()
if
type
(
s
)
is
not
str
:
if
not
isinstance
(
s
,
str
)
:
s
=
" "
.
join
(
s
)
# language-independent part:
for
pattern
,
replace
in
normalize1
:
...
...
lm_eval/tasks/csatqa/_generate_configs.py
View file @
d27c0c08
"""
Take in a YAML, and output all other splits with this YAML
"""
import
os
import
yaml
import
argparse
import
os
import
yaml
from
tqdm
import
tqdm
from
lm_eval.logger
import
eval_logger
SUBSETS
=
[
"WR"
,
"GR"
,
"RCS"
,
"RCSS"
,
"RCH"
,
"LI"
]
...
...
lm_eval/tasks/drop/utils.py
View file @
d27c0c08
...
...
@@ -4,6 +4,7 @@ import string
import
numpy
as
np
from
scipy.optimize
import
linear_sum_assignment
_ARTICLES
=
re
.
compile
(
r
"\b(a|an|the)\b"
,
re
.
UNICODE
)
...
...
lm_eval/tasks/gpqa/n_shot/_generate_configs.py
View file @
d27c0c08
import
yaml
from
tqdm
import
tqdm
...
...
@@ -22,5 +21,6 @@ def main() -> None:
except
FileExistsError
:
pass
if
__name__
==
"__main__"
:
main
()
lm_eval/tasks/gpqa/n_shot/utils.py
View file @
d27c0c08
import
datasets
import
re
import
random
import
re
import
datasets
def
preprocess
(
text
):
if
text
is
None
:
...
...
@@ -11,8 +13,10 @@ def preprocess(text):
text
=
text
.
replace
(
" "
,
" "
)
return
text
rng
=
random
.
Random
(
42
)
def
process_docs
(
dataset
:
datasets
.
Dataset
)
->
datasets
.
Dataset
:
def
_process_doc
(
doc
):
choices
=
[
...
...
@@ -30,7 +34,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
"choice2"
:
choices
[
1
],
"choice3"
:
choices
[
2
],
"choice4"
:
choices
[
3
],
"answer"
:
f
"(
{
chr
(
65
+
correct_answer_index
)
}
)"
"answer"
:
f
"(
{
chr
(
65
+
correct_answer_index
)
}
)"
,
}
return
out_doc
...
...
lm_eval/tasks/gpqa/zeroshot/_generate_configs.py
View file @
d27c0c08
import
yaml
from
tqdm
import
tqdm
...
...
@@ -22,5 +21,6 @@ def main() -> None:
except
FileExistsError
:
pass
if
__name__
==
"__main__"
:
main
()
lm_eval/tasks/gpqa/zeroshot/utils.py
View file @
d27c0c08
import
datasets
import
re
import
random
import
re
import
datasets
def
preprocess
(
text
):
if
text
is
None
:
...
...
@@ -29,7 +31,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
"choice2"
:
choices
[
1
],
"choice3"
:
choices
[
2
],
"choice4"
:
choices
[
3
],
"answer"
:
f
"(
{
chr
(
65
+
correct_answer_index
)
}
)"
"answer"
:
f
"(
{
chr
(
65
+
correct_answer_index
)
}
)"
,
}
return
out_doc
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment