gaoqiong / lm-evaluation-harness — Commits — d27c0c08
Unverified commit d27c0c08, authored Feb 26, 2024 by LSinev, committed by GitHub on Feb 26, 2024.

Apply code autoformatting with Ruff to tasks/*.py and *__init__.py (#1469)
Parent: f78e2da4

Changes: 48 changed files in the full commit. This page shows 20 changed files with 252 additions and 150 deletions (+252, -150); the remaining files continue on the following diff pages.
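The page does not record the exact invocation that produced this commit; a plausible sketch of how such a formatting pass is run with Ruff (the commands, rule selection, and target paths below are assumptions for illustration, not taken from the commit):

# Hypothetical reproduction of the autoformatting pass; paths are illustrative.
ruff check --select I --fix lm_eval/   # sort and group imports (isort-compatible "I" rules)
ruff format lm_eval/                   # apply Black-compatible code formatting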
Files changed on this page (20):

lm_eval/filters/__init__.py                        +3   -4
lm_eval/models/__init__.py                         +14  -10
lm_eval/prompts/__init__.py                        +3   -2
lm_eval/tasks/__init__.py                          +75  -57
lm_eval/tasks/bbh/_generate_configs.py             +3   -3
lm_eval/tasks/bbh/cot_zeroshot/utils.py            +58  -25
lm_eval/tasks/bbh/zeroshot/utils.py                +58  -25
lm_eval/tasks/belebele/_generate_configs.py        +6   -4
lm_eval/tasks/bigbench/generate_tasks.py           +2   -0
lm_eval/tasks/bigbench/push_bigbench_dataset.py    +2   -3
lm_eval/tasks/blimp/generate_configs.py            +1   -0
lm_eval/tasks/ceval/_generate_configs.py           +3   -2
lm_eval/tasks/cmmlu/_generate_configs.py           +3   -2
lm_eval/tasks/code_x_glue/code-text/bleu.py        +3   -3
lm_eval/tasks/csatqa/_generate_configs.py          +3   -2
lm_eval/tasks/drop/utils.py                        +1   -0
lm_eval/tasks/gpqa/n_shot/_generate_configs.py     +1   -1
lm_eval/tasks/gpqa/n_shot/utils.py                 +7   -3
lm_eval/tasks/gpqa/zeroshot/_generate_configs.py   +1   -1
lm_eval/tasks/gpqa/zeroshot/utils.py               +5   -3
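The hunks below are dominated by import reordering and grouping, quote normalization, and line re-wrapping rather than behavioural changes. To preview changes of this kind without writing any files, Ruff can print them as diffs; a minimal sketch under the same assumptions as above:

# Print would-be changes as diffs instead of applying them (paths illustrative).
ruff check --select I --diff lm_eval/tasks
ruff format --diff lm_eval/tasks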
lm_eval/filters/__init__.py  (+3, -4)

-from typing import List, Union
 from functools import partial
+from typing import List, Union

 from lm_eval.api.filter import FilterEnsemble

-from . import selection
-from . import extraction
-from . import transformation
+from . import extraction, selection, transformation

 FILTER_REGISTRY = {
 ...
lm_eval/models/__init__.py  (+14, -10)

-from . import huggingface
-from . import openai_completions
-from . import textsynth
-from . import dummy
-from . import anthropic_llms
-from . import gguf
-from . import vllm_causallms
-from . import mamba_lm
-from . import optimum_lm
-from . import neuron_optimum
+from . import (
+    anthropic_llms,
+    dummy,
+    gguf,
+    huggingface,
+    mamba_lm,
+    neuron_optimum,
+    openai_completions,
+    optimum_lm,
+    textsynth,
+    vllm_causallms,
+)


 # TODO: implement __all__
 ...
lm_eval/prompts/__init__.py  (+3, -2)

-import os
 import ast
+import os
 from typing import Dict

 from lm_eval import utils
 from lm_eval.utils import eval_logger


 # Prompt library.
 # Stores prompts in a dictionary indexed by 2 levels:
 # prompt category name, and prompt name.
 ...
lm_eval/tasks/__init__.py  (+75, -57)

-import os
 import abc
 import collections
+import logging
+import os
 from functools import partial
-from typing import List, Union, Dict
+from typing import Dict, List, Union

 from lm_eval import utils
-from lm_eval.api.task import Task, ConfigurableTask
-
-import logging
+from lm_eval.api.task import ConfigurableTask, Task


 class TaskManager:
...
@@ -16,20 +14,14 @@ class TaskManager:
     and an optional directory if provided.
     """

     def __init__(self, verbosity="INFO", include_path=None) -> None:
         self.verbosity = verbosity
         self.include_path = include_path
         self.logger = utils.eval_logger
         self.logger.setLevel(getattr(logging, f"{verbosity}"))

         self._task_index = self.initialize_tasks(include_path=include_path)
         self._all_tasks = sorted(list(self._task_index.keys()))

         self.task_group_map = collections.defaultdict(list)
...
@@ -65,27 +57,29 @@ class TaskManager:
         return self._task_index

     def match_tasks(self, task_list):
         return utils.pattern_match(task_list, self.all_tasks)

     def _name_is_registered(self, name):
         if name in self.all_tasks:
             return True
         return False

-    def _name_is_task(self, name):
+    def _name_is_task(self, name) -> bool:
         if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
             return True
         return False

     def _name_is_group(self, name):
         if self._name_is_registered(name) and (self.task_index[name]["type"] == "group"):
             return True
         return False

     def _name_is_python_task(self, name):
         if self._name_is_registered(name) and (self.task_index[name]["type"] == "python_task"):
             return True
         return False
...
@@ -117,7 +111,7 @@ class TaskManager:
         return utils.load_yaml_config(yaml_path, mode="full")

     def _get_tasklist(self, name):
-        assert self._name_is_task(name) == False
+        assert self._name_is_task(name) is False
         return self.task_index[name]["task"]

     def _process_alias(self, config, group=None):
...
@@ -130,12 +124,12 @@ class TaskManager:
         return config

     def _load_individual_task_or_group(
         self,
         name_or_config: Union[str, dict] = None,
         parent_name: str = None,
         update_config: dict = None,
         yaml_path: str = None,
     ) -> ConfigurableTask:
         def load_task(config, task, group=None, yaml_path=None):
             if "include" in config:
                 assert yaml_path is not None
...
@@ -174,7 +168,9 @@ class TaskManager:
             group_config = self._get_config(name_or_config)
             if set(group_config.keys()) > set(["task", "group"]):
                 update_config = {
                     k: v for k, v in group_config.items() if k not in ["task", "group"]
                 }
             yaml_path = self._get_yaml_path(group_name)
...
@@ -183,9 +179,8 @@ class TaskManager:
                 update_config.pop("group_alias")

         if isinstance(name_or_config, dict):
             if update_config is not None:
                 name_or_config = {
                     **name_or_config,
                     **update_config,
                 }
...
@@ -196,7 +191,9 @@ class TaskManager:
             # if self._name_is_task(name) is False:
             if self._name_is_group(name):
                 group_name = name
                 update_config = {k: v for k, v in name_or_config.items() if k != "task"}
                 subtask_list = self._get_tasklist(name)
                 if subtask_list == -1:
                     subtask_list = self._get_config(name)["task"]
...
@@ -207,36 +204,53 @@ class TaskManager:
                 # Check if this is a duplicate.
                 if parent_name is not None:
                     name_or_config["group"] = parent_name
-                    num_duplicate = len(list(filter(lambda x: x.startswith(name), self.task_group_map[parent_name])))
+                    num_duplicate = len(
+                        list(
+                            filter(
+                                lambda x: x.startswith(name),
+                                self.task_group_map[parent_name],
+                            )
+                        )
+                    )
                     if num_duplicate > 0:
                         name = f"{name}-{num_duplicate}"
                     self.task_group_map[parent_name].append(name)

                 task_config = {
                     **base_task_config,
                     **name_or_config,
                 }
             else:
                 task_config = name_or_config
             return load_task(task_config, task=name, group=parent_name, yaml_path=yaml_path)
         else:
             group_name = name_or_config["group"]
             subtask_list = name_or_config["task"]
-            # update_config = {k:v for k,v in name_or_config.items() if k != "task"}
             if set(name_or_config.keys()) > set(["task", "group"]):
                 update_config = {
                     k: v for k, v in name_or_config.items() if k not in ["task", "group"]
                 }

         all_subtasks = {}
-        if (parent_name is not None):
+        if parent_name is not None:
             all_subtasks = {group_name: (parent_name, None)}

-        fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config, yaml_path=yaml_path)
-        all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))}
+        fn = partial(
+            self._load_individual_task_or_group,
+            parent_name=group_name,
+            update_config=update_config,
+            yaml_path=yaml_path,
+        )
+        all_subtasks = {
+            **all_subtasks,
+            **dict(collections.ChainMap(*map(fn, subtask_list))),
+        }
         return all_subtasks

     def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
         """Loads a dictionary of task objects from a list
...
@@ -250,12 +264,7 @@ class TaskManager:
             task_list = [task_list]

         all_loaded_tasks = dict(
             collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
         )

         return all_loaded_tasks
...
@@ -299,11 +308,11 @@ class TaskManager:
                 # This is a group config
                 tasks_and_groups[config["group"]] = {
                     "type": "group",
                     "task": -1,  # This signals that
                     # we don't need to know
                     # the task list for indexing
                     # as it can be loaded
                     # when called.
                     "yaml_path": yaml_path,
                 }
...
@@ -322,7 +331,7 @@ class TaskManager:
                 tasks_and_groups[task] = {
                     "type": "task",
                     "yaml_path": yaml_path,
                 }

                 if "group" in config:
                     groups = config["group"]
...
@@ -343,6 +352,7 @@ class TaskManager:
     return tasks_and_groups


 def include_path(task_dir):
     logger = utils.eval_logger
     logger.setLevel(getattr(logging, "INFO"))
...
@@ -352,6 +362,7 @@ def include_path(task_dir):
     )
     return 0


 def initialize_tasks(verbosity="INFO"):
     logger = utils.eval_logger
     logger.setLevel(getattr(logging, f"{verbosity}"))
...
@@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"):
     )
     return 0


 def get_task_name_from_config(task_config: Dict[str, str]) -> str:
     if "task" in task_config:
         return task_config["task"]
...
@@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
     else:
         return "{dataset_path}".format(**task_config)


 def get_task_name_from_object(task_object):
     if hasattr(task_object, "config"):
         return task_object._config["task"]
...
@@ -382,7 +395,10 @@ def get_task_name_from_object(task_object):
         else type(task_object).__name__
     )


 def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None):
     """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.

     :param task_name_list: List[Union[str, Dict, Task]]
...
@@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None):
     if task_manager is None:
         task_manager = TaskManager()

     task_name_from_string_dict = task_manager.load_task_or_group(string_task_name_list)

     for task_element in others_task_name_list:
         if isinstance(task_element, dict):
...
lm_eval/tasks/bbh/_generate_configs.py  (+3, -3)

 """
 Take in a YAML, and output all other splits with this YAML
 """
+import argparse
 import os
 import re
-import yaml
-import requests
-import argparse
 import datasets
+import requests
+import yaml
 from tqdm import tqdm

 from lm_eval import utils
 ...
lm_eval/tasks/bbh/cot_zeroshot/utils.py  (+58, -25)

 import collections
 import re
 import sys
 import unicodedata

-from lm_eval.filters.extraction import RegexFilter, Filter
+from lm_eval.filters.extraction import Filter, RegexFilter


 class ExtendedRegexFilter(RegexFilter):
-    punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode)
-                              if unicodedata.category(chr(i)).startswith('P'))
+    punct_tbl = dict.fromkeys(
+        i
+        for i in range(sys.maxunicode)
+        if unicodedata.category(chr(i)).startswith("P")
+    )

     def __init__(
-        self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]",
-        ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None,
+        self,
+        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
+        group_select=0,
+        fallback: str = "[invalid]",
+        ignore_case=False,
+        ignore_punctuation=False,
+        regexes_to_ignore=None,
     ) -> None:
         super().__init__(regex_pattern, group_select, fallback)
         self.ignore_case = ignore_case
...
@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter):
 class MapRegexFilter(ExtendedRegexFilter):
     def __init__(
-        self, regex_pattern_to_value: dict = {}, group_select=0, fallback: str = "[invalid]",
-        ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None,
+        self,
+        regex_pattern_to_value: dict = {},
+        group_select=0,
+        fallback: str = "[invalid]",
+        ignore_case=False,
+        ignore_punctuation=False,
+        regexes_to_ignore=None,
     ) -> None:
         """
         regex_pattern_to_value: Match the regex pattern and change the result into the value
...
@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter):
         ignore_punctuation: Remove the punctuation before matching with the given regex
         regexes_to_ignore: Remove these regexes before matching with the given regex
         """
-        super().__init__('|'.join(list(regex_pattern_to_value.keys())), group_select, fallback, ignore_case, ignore_punctuation, regexes_to_ignore)
-        self.regex_to_value = {re.compile(r): v for r, v in regex_pattern_to_value.items()}
+        super().__init__(
+            "|".join(list(regex_pattern_to_value.keys())),
+            group_select,
+            fallback,
+            ignore_case,
+            ignore_punctuation,
+            regexes_to_ignore,
+        )
+        self.regex_to_value = {
+            re.compile(r): v for r, v in regex_pattern_to_value.items()
+        }

     def apply(self, resps, docs):
         filtered_resps = []
...
@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter):
         for r in resps:
             filtered = []
             for resp in r:
-                whole_match_considering_group_select = self.find_match(self.regex, self.filter_ignores(resp))
+                whole_match_considering_group_select = self.find_match(
+                    self.regex, self.filter_ignores(resp)
+                )
                 if whole_match_considering_group_select:
                     for regex, mapped_value in self.regex_to_value.items():
-                        match = self.find_match(regex, self.filter_ignores(whole_match_considering_group_select))
+                        match = self.find_match(
+                            regex,
+                            self.filter_ignores(whole_match_considering_group_select),
+                        )
                         if match:
                             match = mapped_value
                             break
...
@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter):
         filtered_resps = []
         import regex
         from word2number import w2n

         # https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words
         english_number_regex = regex.compile(
             "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
         )
         for r in resps:
             filtered = []
...
@@ -118,21 +144,22 @@ class WordSortFilter(Filter):
         filtered_resps = []

         for r, doc in zip(resps, docs):
-            words = doc['input'].split("List:")[1].strip().split()
-            regex = re.compile('|'.join([f"\\b{w}\\b" for w in words]))
+            words = doc["input"].split("List:")[1].strip().split()
+            regex = re.compile("|".join([f"\\b{w}\\b" for w in words]))
             filtered = []
             for resp in r:
                 match = regex.findall(resp)
                 match.reverse()
-                ordered_words = reversed(collections.OrderedDict(zip(match, [None] * len(match))))
-                filtered.append(' '.join(ordered_words))
+                ordered_words = reversed(
+                    collections.OrderedDict(zip(match, [None] * len(match)))
+                )
+                filtered.append(" ".join(ordered_words))
             filtered_resps.append(filtered)

         return filtered_resps


 class MultiChoiceRegexFilter(ExtendedRegexFilter):
     def __init__(self, *args, **kwargs):
         """
         regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
...
@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
         for r, doc in zip(resps, docs):
             fallback_regexes = []
             choice_to_alpha = {}
-            next_alpha = 'A'
+            next_alpha = "A"

             without_paren_fallback_regexes = []
             without_paren_to_target = {}

             multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)")
-            match = multiple_choices_regex.findall(doc['input'])
+            match = multiple_choices_regex.findall(doc["input"])
             for m in match:
                 m = self.filter_ignores(m.strip())
                 fallback_regexes.append(f"{re.escape(m)}")
...
@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
                 without_paren_to_target[next_alpha] = f"({next_alpha})"
                 next_alpha = chr(ord(next_alpha) + 1)

-            fallback_regex = re.compile('|'.join(fallback_regexes))
-            without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes)
+            fallback_regex = re.compile("|".join(fallback_regexes))
+            without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
             without_paren_fallback_regex = re.compile(
                 f":[\s]*({without_paren_fallback_regex})"
             )

             filtered = []
             for resp in r:
                 match = self.find_match(self.regex, resp)
                 if not match:
-                    match = self.find_match(fallback_regex, self.filter_ignores(resp), choice_to_alpha)
+                    match = self.find_match(
+                        fallback_regex, self.filter_ignores(resp), choice_to_alpha
+                    )
                 if not match:
-                    match = self.find_match(without_paren_fallback_regex, resp, without_paren_to_target)
+                    match = self.find_match(
+                        without_paren_fallback_regex, resp, without_paren_to_target
+                    )
                 if not match:
                     match = self.fallback
                 filtered.append(match)
...
lm_eval/tasks/bbh/zeroshot/utils.py  (+58, -25)

 import collections
 import re
 import sys
 import unicodedata

-from lm_eval.filters.extraction import RegexFilter, Filter
+from lm_eval.filters.extraction import Filter, RegexFilter


 class ExtendedRegexFilter(RegexFilter):
-    punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode)
-                              if unicodedata.category(chr(i)).startswith('P'))
+    punct_tbl = dict.fromkeys(
+        i
+        for i in range(sys.maxunicode)
+        if unicodedata.category(chr(i)).startswith("P")
+    )

     def __init__(
-        self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]",
-        ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None,
+        self,
+        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
+        group_select=0,
+        fallback: str = "[invalid]",
+        ignore_case=False,
+        ignore_punctuation=False,
+        regexes_to_ignore=None,
     ) -> None:
         super().__init__(regex_pattern, group_select, fallback)
         self.ignore_case = ignore_case
...
@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter):
 class MapRegexFilter(ExtendedRegexFilter):
     def __init__(
-        self, regex_pattern_to_value: dict = {}, group_select=0, fallback: str = "[invalid]",
-        ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None,
+        self,
+        regex_pattern_to_value: dict = {},
+        group_select=0,
+        fallback: str = "[invalid]",
+        ignore_case=False,
+        ignore_punctuation=False,
+        regexes_to_ignore=None,
     ) -> None:
         """
         regex_pattern_to_value: Match the regex pattern and change the result into the value
...
@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter):
         ignore_punctuation: Remove the punctuation before matching with the given regex
         regexes_to_ignore: Remove these regexes before matching with the given regex
         """
-        super().__init__('|'.join(list(regex_pattern_to_value.keys())), group_select, fallback, ignore_case, ignore_punctuation, regexes_to_ignore)
-        self.regex_to_value = {re.compile(r): v for r, v in regex_pattern_to_value.items()}
+        super().__init__(
+            "|".join(list(regex_pattern_to_value.keys())),
+            group_select,
+            fallback,
+            ignore_case,
+            ignore_punctuation,
+            regexes_to_ignore,
+        )
+        self.regex_to_value = {
+            re.compile(r): v for r, v in regex_pattern_to_value.items()
+        }

     def apply(self, resps, docs):
         filtered_resps = []
...
@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter):
         for r in resps:
             filtered = []
             for resp in r:
-                whole_match_considering_group_select = self.find_match(self.regex, self.filter_ignores(resp))
+                whole_match_considering_group_select = self.find_match(
+                    self.regex, self.filter_ignores(resp)
+                )
                 if whole_match_considering_group_select:
                     for regex, mapped_value in self.regex_to_value.items():
-                        match = self.find_match(regex, self.filter_ignores(whole_match_considering_group_select))
+                        match = self.find_match(
+                            regex,
+                            self.filter_ignores(whole_match_considering_group_select),
+                        )
                         if match:
                             match = mapped_value
                             break
...
@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter):
         filtered_resps = []
         import regex
         from word2number import w2n

         # https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words
         english_number_regex = regex.compile(
             "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
         )
         for r in resps:
             filtered = []
...
@@ -118,21 +144,22 @@ class WordSortFilter(Filter):
         filtered_resps = []

         for r, doc in zip(resps, docs):
-            words = doc['input'].split("List:")[1].strip().split()
-            regex = re.compile('|'.join([f"\\b{w}\\b" for w in words]))
+            words = doc["input"].split("List:")[1].strip().split()
+            regex = re.compile("|".join([f"\\b{w}\\b" for w in words]))
             filtered = []
             for resp in r:
                 match = regex.findall(resp)
                 match.reverse()
-                ordered_words = reversed(collections.OrderedDict(zip(match, [None] * len(match))))
-                filtered.append(' '.join(ordered_words))
+                ordered_words = reversed(
+                    collections.OrderedDict(zip(match, [None] * len(match)))
+                )
+                filtered.append(" ".join(ordered_words))
             filtered_resps.append(filtered)

         return filtered_resps


 class MultiChoiceRegexFilter(ExtendedRegexFilter):
     def __init__(self, *args, **kwargs):
         """
         regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
...
@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
         for r, doc in zip(resps, docs):
             fallback_regexes = []
             choice_to_alpha = {}
-            next_alpha = 'A'
+            next_alpha = "A"

             without_paren_fallback_regexes = []
             without_paren_to_target = {}

             multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)")
-            match = multiple_choices_regex.findall(doc['input'])
+            match = multiple_choices_regex.findall(doc["input"])
             for m in match:
                 m = self.filter_ignores(m.strip())
                 fallback_regexes.append(f"{re.escape(m)}")
...
@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
                 without_paren_to_target[next_alpha] = f"({next_alpha})"
                 next_alpha = chr(ord(next_alpha) + 1)

-            fallback_regex = re.compile('|'.join(fallback_regexes))
-            without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes)
+            fallback_regex = re.compile("|".join(fallback_regexes))
+            without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
             without_paren_fallback_regex = re.compile(
                 f":[\s]*({without_paren_fallback_regex})"
             )

             filtered = []
             for resp in r:
                 match = self.find_match(self.regex, resp)
                 if not match:
-                    match = self.find_match(fallback_regex, self.filter_ignores(resp), choice_to_alpha)
+                    match = self.find_match(
+                        fallback_regex, self.filter_ignores(resp), choice_to_alpha
+                    )
                 if not match:
-                    match = self.find_match(without_paren_fallback_regex, resp, without_paren_to_target)
+                    match = self.find_match(
+                        without_paren_fallback_regex, resp, without_paren_to_target
+                    )
                 if not match:
                     match = self.fallback
                 filtered.append(match)
...
lm_eval/tasks/belebele/_generate_configs.py  (+6, -4)

 """
 Take in a YAML, and output all other splits with this YAML
 """
-import os
-import yaml
 import argparse
-import requests
+import os
+
+import requests
+import yaml
 from tqdm import tqdm

 from lm_eval.utils import logging


 API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele"
...
@@ -39,6 +40,7 @@ if __name__ == "__main__":
     def query():
         response = requests.get(API_URL)
         return response.json()["splits"]

     print(query())

     languages = [split["split"] for split in query()]
...
@@ -49,7 +51,7 @@ if __name__ == "__main__":
             if args.task_prefix != ""
             else f"belebele_{lang}",
             "test_split": lang,
             "fewshot_split": lang,
         }

         file_save_path = args.save_prefix_path + f"_{lang}.yaml"
...
lm_eval/tasks/bigbench/generate_tasks.py  (+2, -0)

 import os

 import yaml


 all_subtasks = [
     "abstract_narrative_understanding",
     "anachronisms",
 ...
lm_eval/tasks/bigbench/push_bigbench_dataset.py  (+2, -3)

...
@@ -8,10 +8,9 @@ Requires the installation of
 `pip install "bigbench @ https://storage.googleapis.com/public_research_data/bigbench/bigbench-0.0.1.tar.gz"`
 and is included so that the bigbench dependency can be avoided.
 """
-from tqdm import tqdm
-import datasets
 import bigbench.api.util as bb_utils
+import datasets
+from tqdm import tqdm

 all_task_names = bb_utils.get_all_json_task_names()
...
lm_eval/tasks/blimp/generate_configs.py  (+1, -0)

 import yaml


 all_subtasks = [
     "adjunct_island",
     "anaphor_gender_agreement",
 ...
lm_eval/tasks/ceval/_generate_configs.py  (+3, -2)

 """
 Take in a YAML, and output all other splits with this YAML
 """
-import os
-import yaml
 import argparse
+import os
+
+import yaml
 from tqdm import tqdm

 from lm_eval.logger import eval_logger


 SUBJECTS = {
     "computer_network": "计算机网络",
     "operating_system": "操作系统",
 ...
lm_eval/tasks/cmmlu/_generate_configs.py  (+3, -2)

 """
 Take in a YAML, and output all other splits with this YAML
 """
-import os
-import yaml
 import argparse
+import os
+
+import yaml
 from tqdm import tqdm

 from lm_eval.logger import eval_logger


 SUBJECTS = {
     "agronomy": "农学",
     "anatomy": "解剖学",
 ...
lm_eval/tasks/code_x_glue/code-text/bleu.py  (+3, -3)

 #!/usr/bin/python

+import math
 import re
 import sys
-import math
 import xml.sax.saxutils
-from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
+from typing import Any, Dict, List, Optional, Pattern, Tuple, Union

 """
 This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
...
@@ -60,7 +60,7 @@ def normalize(s):
     # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
     if nonorm:
         return s.split()
-    if type(s) is not str:
+    if not isinstance(s, str):
         s = " ".join(s)
     # language-independent part:
     for pattern, replace in normalize1:
...
lm_eval/tasks/csatqa/_generate_configs.py  (+3, -2)

 """
 Take in a YAML, and output all other splits with this YAML
 """
-import os
-import yaml
 import argparse
+import os
+
+import yaml
 from tqdm import tqdm

 from lm_eval.logger import eval_logger

 SUBSETS = ["WR", "GR", "RCS", "RCSS", "RCH", "LI"]
...
lm_eval/tasks/drop/utils.py  (+1, -0)

...
@@ -4,6 +4,7 @@ import string
 import numpy as np
 from scipy.optimize import linear_sum_assignment


 _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
...
lm_eval/tasks/gpqa/n_shot/_generate_configs.py  (+1, -1)

 import yaml
 from tqdm import tqdm
...
@@ -22,5 +21,6 @@ def main() -> None:
     except FileExistsError:
         pass


 if __name__ == "__main__":
     main()
lm_eval/tasks/gpqa/n_shot/utils.py  (+7, -3)

-import datasets
-import re
 import random
+import re
+
+import datasets


 def preprocess(text):
     if text is None:
...
@@ -11,8 +13,10 @@ def preprocess(text):
     text = text.replace("  ", " ")
     return text


 rng = random.Random(42)


 def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
     def _process_doc(doc):
         choices = [
...
@@ -30,7 +34,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
             "choice2": choices[1],
             "choice3": choices[2],
             "choice4": choices[3],
-            "answer": f"({chr(65 + correct_answer_index)})"
+            "answer": f"({chr(65 + correct_answer_index)})",
         }
         return out_doc
...
lm_eval/tasks/gpqa/zeroshot/_generate_configs.py  (+1, -1)

 import yaml
 from tqdm import tqdm
...
@@ -22,5 +21,6 @@ def main() -> None:
     except FileExistsError:
         pass


 if __name__ == "__main__":
     main()
lm_eval/tasks/gpqa/zeroshot/utils.py  (+5, -3)

-import datasets
-import re
 import random
+import re
+
+import datasets


 def preprocess(text):
     if text is None:
...
@@ -29,7 +31,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
             "choice2": choices[1],
             "choice3": choices[2],
             "choice4": choices[3],
-            "answer": f"({chr(65 + correct_answer_index)})"
+            "answer": f"({chr(65 + correct_answer_index)})",
         }
         return out_doc
...