gaoqiong / lm-evaluation-harness · Commits

commit 748a9898
Author: lintangsutawika
Date:   Jul 19, 2023

    can now process a benchmark that uses promptsource

parent 7411466c
Showing 3 changed files with 160 additions and 32 deletions.

    lm_eval/tasks/__init__.py               +75  -14
    lm_eval/tasks/benchmarks/t0_eval.yaml   +81  -18
    lm_eval/utils.py                         +4   -0
lm_eval/tasks/__init__.py
@@ -3,6 +3,7 @@ import yaml
 from typing import List, Union

 from lm_eval import utils
+from lm_eval import prompts
 from lm_eval.logger import eval_logger
 from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
 from lm_eval.api.registry import (
@@ -14,6 +15,59 @@ from lm_eval.api.registry import (
 )

+
+def register_configurable_task(config):
+    SubClass = type(
+        config["task"] + "ConfigurableTask",
+        (ConfigurableTask,),
+        {"CONFIG": TaskConfig(**config)},
+    )
+
+    if "task" in config:
+        task_name = "{}".format(config["task"])
+        register_task(task_name)(SubClass)
+
+    if "group" in config:
+        if type(config["group"]) == str:
+            group_name = [config["group"]]
+        else:
+            group_name = config["group"]
+
+        for group in group_name:
+            register_group(group)(SubClass)
+
+    return 0
+
+
+def check_prompt_config(config):
+    all_configs = []
+    if "use_prompt" in config:
+        prompt_list = prompts.load_prompt_list(
+            use_prompt=config["use_prompt"],
+            dataset_name=config["dataset_path"],
+            subset_name=config["dataset_name"],
+        )
+        for idx, prompt_variation in enumerate(prompt_list):
+            all_configs.append(
+                {
+                    **config,
+                    **{"use_prompt": prompt_variation},
+                    **{
+                        "task": "_".join(
+                            [
+                                get_task_name_from_config(config),
+                                "promptsource",
+                                str(idx).zfill(2),
+                            ]
+                        )
+                    },
+                    **{"output_type": "greedy_until"},
+                }
+            )
+    else:
+        all_configs.append(config)
+    return all_configs
+
+
 def get_task_name_from_config(task_config):
     if "dataset_name" in task_config:
         return "{dataset_path}_{dataset_name}".format(**task_config)
@@ -32,20 +86,10 @@ def include_task_folder(task_dir):
             yaml_path = os.path.join(root, f)
             try:
                 config = utils.load_yaml_config(yaml_path)
-                SubClass = type(
-                    config["task"] + "ConfigurableTask",
-                    (ConfigurableTask,),
-                    {"CONFIG": TaskConfig(**config)},
-                )
-
-                if "task" in config:
-                    task_name = "{}".format(config["task"])
-                    register_task(task_name)(SubClass)
-
-                if "group" in config:
-                    for group in config["group"]:
-                        register_group(group)(SubClass)
+                all_configs = check_prompt_config(config)
+                for config in all_configs:
+                    register_configurable_task(config)
             except Exception as error:
                 eval_logger.warning(
                     "Failed to load config in\n"
@@ -69,7 +113,24 @@ def include_benchmarks(task_dir, benchmark_dir="benchmarks"):
                 assert "group" in yaml_config
                 group = yaml_config["group"]
-                task_list = yaml_config["task"]
+                all_task_list = yaml_config["task"]
+                config_list = [task for task in all_task_list if type(task) != str]
+                task_list = [task for task in all_task_list if type(task) == str]
+
+                for task_config in config_list:
+                    var_configs = check_prompt_config(
+                        {
+                            **task_config,
+                            **{"group": group},
+                        }
+                    )
+                    for config in var_configs:
+                        register_configurable_task(config)

                 task_names = utils.pattern_match(task_list, ALL_TASKS)
                 for task in task_names:
                     if task in TASK_REGISTRY:
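Note: a minimal sketch of the fan-out that the new check_prompt_config performs when a config carries use_prompt. The dataset fields mirror a t0_eval.yaml entry below, but the two template names are hypothetical stand-ins for whatever prompts.load_prompt_list would actually return; only the dict merging and the naming scheme follow the diff.

    # Sketch of check_prompt_config's fan-out (template names are
    # hypothetical; the merge and naming scheme mirror the diff).
    config = {
        "dataset_path": "super_glue",
        "dataset_name": "cb",
        "use_prompt": "promptsource:*",
    }

    # Assumed stand-in for prompts.load_prompt_list(...)
    prompt_list = ["template_a", "template_b"]

    all_configs = [
        {
            **config,
            "use_prompt": prompt_variation,
            # task name: <dataset_path>_<dataset_name>_promptsource_<idx>
            "task": "_".join(["super_glue_cb", "promptsource", str(idx).zfill(2)]),
            "output_type": "greedy_until",
        }
        for idx, prompt_variation in enumerate(prompt_list)
    ]

    for c in all_configs:
        print(c["task"], "->", c["use_prompt"])
    # super_glue_cb_promptsource_00 -> template_a
    # super_glue_cb_promptsource_01 -> template_b

Each expanded config is then handed to register_configurable_task, so every prompt variation becomes its own registered task.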
lm_eval/tasks/benchmarks/t0_eval.yaml
 group:
-  t0_eval
+  - t0_eval
 task:
+  # Coreference Resolution
   - dataset_path: super_glue
     dataset_name: wsc.fixed
-    use_prompt: promptsource
+    use_prompt: promptsource:*
+    training_split: train
+    validation_split: validation
+    metric_list:
+      - metric: exact_match
+        aggregation: mean
+        higher_is_better: true
+        ignore_case: true
+        ignore_punctuation: true
+  # Coreference Resolution
   - dataset_path: winogrande
     dataset_name: winogrande_xl
-    use_prompt: promptsource
+    use_prompt: promptsource:*
+    training_split: train
+    validation_split: validation
+    metric_list:
+      - metric: exact_match
+        aggregation: mean
+        higher_is_better: true
+        ignore_case: true
+        ignore_punctuation: true
+  # Natural Language Inference
   - dataset_path: super_glue
     dataset_name: cb
-    use_prompt: promptsource
+    use_prompt: promptsource:*
+    training_split: train
+    validation_split: validation
+    metric_list:
+      - metric: exact_match
+        aggregation: mean
+        higher_is_better: true
+        ignore_case: true
+        ignore_punctuation: true
+  # Natural Language Inference
   - dataset_path: super_glue
     dataset_name: rte
-    use_prompt: promptsource
+    use_prompt: promptsource:*
+    training_split: train
+    validation_split: validation
+    metric_list:
+      - metric: exact_match
+        aggregation: mean
+        higher_is_better: true
+        ignore_case: true
+        ignore_punctuation: true
-  - dataset_path: anli
-    use_prompt: promptsource
+  # Natural Language Inference
+  # - dataset_path: anli
+  #   use_prompt: promptsource:*
+  # Sentence Completion
   - dataset_path: super_glue
     dataset_name: copa
-    use_prompt: promptsource
+    use_prompt: promptsource:*
+    training_split: train
+    validation_split: validation
+    metric_list:
+      - metric: exact_match
+        aggregation: mean
+        higher_is_better: true
+        ignore_case: true
+        ignore_punctuation: true
+  # Natural Language Inference
   - dataset_path: hellaswag
-    use_prompt: promptsource
+    use_prompt: promptsource:*
+    training_split: train
+    validation_split: validation
+    metric_list:
+      - metric: exact_match
+        aggregation: mean
+        higher_is_better: true
+        ignore_case: true
+        ignore_punctuation: true
+  # Word Sense Disambiguation
   - dataset_path: super_glue
     dataset_name: wic
-    use_prompt: promptsource
+    use_prompt: promptsource:*
+    training_split: train
+    validation_split: validation
+    metric_list:
+      - metric: exact_match
+        aggregation: mean
+        higher_is_better: true
+        ignore_case: true
+        ignore_punctuation: true
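Note: every entry in this benchmark's task list is now an inline dict rather than a registered task name, which is exactly the case the include_benchmarks change above handles. A minimal sketch of that string/dict split, pairing a hypothetical plain task name with one dict entry taken from this file:

    # Sketch of the split include_benchmarks now applies to a benchmark's
    # `task` list; "some_registered_task" is a hypothetical plain name.
    all_task_list = [
        "some_registered_task",
        {
            "dataset_path": "super_glue",
            "dataset_name": "wic",
            "use_prompt": "promptsource:*",
        },
    ]

    config_list = [task for task in all_task_list if type(task) != str]
    task_list = [task for task in all_task_list if type(task) == str]

    print(task_list)    # ['some_registered_task'] -> resolved via utils.pattern_match
    print(config_list)  # dict configs -> check_prompt_config, then
                        #    register_configurable_task under the benchmark's group

Because each dict config is merged with {"group": group} before registration, all of the expanded per-prompt tasks end up grouped under t0_eval.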
lm_eval/utils.py
@@ -108,6 +108,10 @@ class MultiChoice:
 # Returns a list containing all values of the source_list that
 # match at least one of the patterns
 def pattern_match(patterns, source_list):
+    if type(patterns) == str:
+        patterns = [patterns]
+
     task_names = set()
     for pattern in patterns:
         for matching in fnmatch.filter(source_list, pattern):
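Note: the new guard matters because a bare string is itself iterable, so without it each character would be treated as a separate glob pattern. A self-contained sketch follows; the hunk truncates after fnmatch.filter, so the completion of the loop body (task_names.add plus a sorted return) is an assumption, as are the registry names:

    import fnmatch

    def pattern_match(patterns, source_list):
        # New guard from the diff: wrap a bare string so the loop below
        # iterates over patterns, not over the string's characters.
        if type(patterns) == str:
            patterns = [patterns]

        task_names = set()
        for pattern in patterns:
            for matching in fnmatch.filter(source_list, pattern):
                task_names.add(matching)  # assumed continuation of the hunk
        return sorted(task_names)         # assumed return

    registry = [
        "super_glue_cb_promptsource_00",
        "super_glue_cb_promptsource_01",
        "winogrande_winogrande_xl_promptsource_00",
    ]
    print(pattern_match("super_glue_cb_promptsource_*", registry))
    # ['super_glue_cb_promptsource_00', 'super_glue_cb_promptsource_01']

This is what lets the zero-padded promptsource task names generated above be selected with a single wildcard pattern.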