lm-evaluation-harness, commit f6b76f5d (unverified)
Authored May 09, 2023 by Hailey Schoelkopf; committed via GitHub on May 09, 2023
Parents: 95642aa6, f1beac00

Merge pull request #486 from EleutherAI/yaml-parameterize

[Refactor] Add decorator for registering YAMLs as tasks
Showing 11 changed files with 91 additions and 11 deletions (+91 / -11):

  lm_eval/api/task.py                    +50  -6
  lm_eval/evaluator.py                    +2  -1
  lm_eval/tasks/__init__.py              +10  -1
  lm_eval/tasks/yaml/arc_challenge.yaml   +2  -0
  lm_eval/tasks/yaml/arc_easy.yaml        +2  -0
  lm_eval/tasks/yaml/gsm8k.yaml           +2  -0
  lm_eval/tasks/yaml/lambada.yaml         +2  -0
  lm_eval/tasks/yaml/pile_enron.yaml      +2  -0
  lm_eval/tasks/yaml/sglue_cb.yaml        +2  -0
  lm_eval/utils.py                       +13  -0
  main.py                                 +4  -3
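In short, the change lets a task be described entirely by a YAML file: register_yaml_task() parses the YAML, wraps it in a TaskConfig, and synthesizes a ConfigurableTask subclass registered under the YAML's `names` aliases. A hedged sketch of registering one extra YAML by hand; the file name and alias here are made up for illustration and are not part of this diff:

# Hedged sketch: manual registration of a hypothetical YAML task definition.
from lm_eval.api.task import register_yaml_task, TASK_REGISTRY

register_yaml_task("my_task.yaml")    # YAML must contain a `names:` list, e.g. ["my_task_yaml"]
print(TASK_REGISTRY["my_task_yaml"])  # -> dynamically created "<alias>ConfigurableTask" subclass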
lm_eval/api/task.py

@@ -27,7 +27,8 @@ from lm_eval.api import samplers
 @dataclass
 class TaskConfig(dict):
-    task_name: str = None
+    names: str = None
+    task_name: str = None  # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
     dataset_path: str = None
     dataset_name: str = None
     training_split: str = None

@@ -54,6 +55,8 @@ class TaskConfig(dict):
     doc_to_decontamination_query: str = None
     use_prompt: str = None
+    metadata: str = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
+
     def __post_init__(self):
         # allow user-specified aliases so that users can
         # force prompt-compatibility for some prompt regardless of

@@ -61,6 +64,10 @@ class TaskConfig(dict):
             self.doc_to_text = self.template_aliases + self.doc_to_text
             self.doc_to_target = self.template_aliases + self.doc_to_target
+        # set "task_name" metadata field based on the "primary" name set
+        if self.names:
+            self.task_name = self.names[0]
+
     def __getitem__(self, item):
         return getattr(self, item)

@@ -268,7 +275,7 @@ class Task(abc.ABC):
             )
             # TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute
-            inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self._config["task_name"], doc_id, 2))
+            inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self._config["task_name"], doc_id, 1))
             if not isinstance(inst, list):
                 inst = [inst]

@@ -405,12 +412,18 @@ class ConfigurableTask(Task):
     VERSION = "2.0"
     OUTPUT_TYPE = None
+    CONFIG = None

     def __init__(
         self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
     ):
-        self._config = TaskConfig(**config)
+        # if we are a subclass that has the CONFIG class attr set, ignore whatever is passed.
+        self._config = self.CONFIG
+        # else, if a config was passed as kwarg: use it
+        if (self._config is None) and config:
+            self._config = TaskConfig(**config)
+        if self._config is None:
+            raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")

         if self._config.output_type is not None:
             self.OUTPUT_TYPE = self._config.output_type

@@ -620,7 +633,6 @@ class ConfigurableTask(Task):
             }
         # TODO: set which normalization metrics should be reported, and calculate them
         # TODO: add mutual info.
-
         if "exact_match" in self._metric_list.keys():
             # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly

@@ -670,7 +682,7 @@ class MultipleChoiceTask(Task):
         return " " + doc["choices"][doc["gold"]]

     def construct_requests(self, doc, ctx, **kwargs):
         # TODO: add mutual info here?
         return [
             Instance(
                 request_type="loglikelihood",
                 doc=doc,

@@ -803,6 +815,38 @@ def register_task(*names):
     return decorate


+def register_yaml_task(yaml_path):
+    # same goal as register_task() but used to register yamls
+    import yaml
+
+    with open(yaml_path, "r") as f:
+        config = yaml.load(f, yaml.Loader)
+
+    from functools import partial
+
+    # TODO: strip whitespace from name?
+    # TODO: ensure num_fewshot overrides the config vals
+    def decorate(names, cls):
+        for name in names:
+            assert (
+                issubclass(cls, Task)
+            ), f"Task '{name}' ({cls.__name__}) must extend Task class"
+
+            assert (
+                name not in TASK_REGISTRY
+            ), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
+
+            TASK_REGISTRY[name] = cls
+        ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import properly.
+        return cls
+
+    # we create a subclass that has subclass attr CONFIG = our yaml config, and decorate with the config's specified aliases
+    names = config['names']
+    yaml_task = decorate(
+        names,
+        type(config['names'][0] + 'ConfigurableTask', (ConfigurableTask,), {'CONFIG': TaskConfig(**config)}),
+    )


 ##### Task registry utils and setup.
 # ALL_TASKS = sorted(list(TASK_REGISTRY))
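The core trick in register_yaml_task is building a ConfigurableTask subclass dynamically with type() and baking the parsed YAML into a CONFIG class attribute, which __init__ now prefers over the config kwarg. A minimal, self-contained sketch of that pattern; DemoConfig, BaseTask, and the names below are illustrative stand-ins, not classes from lm_eval:

# Minimal illustration of the CONFIG-class-attribute pattern used above (stand-in classes).
from dataclasses import dataclass


@dataclass
class DemoConfig:
    names: list = None
    dataset_path: str = None


class BaseTask:
    CONFIG = None  # subclasses may pin their config here

    def __init__(self, config: dict = None):
        # prefer the class-level CONFIG; fall back to the kwarg
        self._config = self.CONFIG
        if self._config is None and config:
            self._config = DemoConfig(**config)
        if self._config is None:
            raise ValueError("need a config via cls.CONFIG or the `config` kwarg")


# type(name, bases, namespace) creates the subclass at runtime,
# exactly like the decorator does for each YAML file.
cfg = DemoConfig(names=["demo_yaml"], dataset_path="some/dataset")
DemoYamlTask = type("demo_yamlConfigurableTask", (BaseTask,), {"CONFIG": cfg})

task = DemoYamlTask()  # no kwargs needed: CONFIG is baked in
assert task._config is cfg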
lm_eval/evaluator.py

@@ -6,7 +6,7 @@ import lm_eval.api.metrics
 import lm_eval.models
 import lm_eval.tasks
 import lm_eval.api

-from lm_eval.utils import positional_deprecated, run_task_tests, make_table
+from lm_eval.utils import positional_deprecated, run_task_tests, make_table, get_git_commit_hash


 @positional_deprecated

@@ -90,6 +90,7 @@ def simple_evaluate(
         "limit": limit,
         "bootstrap_iters": bootstrap_iters,
     }
+    results["git_hash"] = get_git_commit_hash()

     return results
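The evaluator now stamps every results dict with the commit it ran at, via get_git_commit_hash() (added to lm_eval/utils.py below). A hedged sketch of one way a caller might use that stamp when persisting results; save_results is a hypothetical helper, not part of this diff:

# Hedged sketch: tag persisted results with the commit that produced them.
import json

def save_results(results: dict, path: str) -> None:
    tag = results.get("git_hash") or "unknown"  # set by simple_evaluate via get_git_commit_hash()
    with open(f"{path}.{tag}.json", "w") as f:
        json.dump(results, f, indent=2, default=str)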
lm_eval/tasks/__init__.py

# from lm_eval.api.task import register_yaml_task
import os

from lm_eval.api.task import register_yaml_task
from .vanilla import *

# we want to register all yaml tasks in our .yaml folder.
yaml_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + "yaml"
for yaml in sorted(os.listdir(yaml_dir)):
    yaml = os.path.join(yaml_dir, yaml)
    register_yaml_task(yaml)
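Because lm_eval/tasks/__init__.py now walks the yaml/ directory at import time, dropping a new YAML file there is enough to register a task. A hedged sketch of checking that the scan picked up the aliases defined in the YAML files below; nothing else in this snippet comes from the diff:

# Hedged sketch: verify the import-time YAML scan registered the expected aliases.
import lm_eval.tasks  # importing the package triggers the register_yaml_task() loop above
from lm_eval.api.task import TASK_REGISTRY

for alias in ("arc_easy_yaml", "arc_challenge_yaml", "gsm8k_yaml"):
    assert alias in TASK_REGISTRY, f"{alias} was not registered"
print(sorted(TASK_REGISTRY))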
lm_eval/tasks/yaml/arc_challenge.yaml

names:
  - arc_challenge_yaml
dataset_path: ai2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
lm_eval/tasks/yaml/arc_easy.yaml

names:
  - arc_easy_yaml
dataset_path: ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
lm_eval/tasks/yaml/gsm8k.yaml

names:
  - gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
training_split: train
lm_eval/tasks/yaml/lambada.yaml

names:
  - lambada_openai_yaml
dataset_path: EleutherAI/lambada_openai
dataset_name: default
output_type: loglikelihood
lm_eval/tasks/yaml/pile_enron.yaml

names:
  - pile_enron_yaml
dataset_path: EleutherAI/the_pile
dataset_name: enron_emails
output_type: loglikelihood_rolling
lm_eval/tasks/yaml/sglue_cb.yaml

names:
  - sglue_cb_yamltest
dataset_path: super_glue
dataset_name: cb
training_split: train
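Each of these YAMLs is just a set of keyword arguments for TaskConfig. A hedged sketch of what register_yaml_task effectively does with one of them; the inline string stands in for a file such as lm_eval/tasks/yaml/arc_easy.yaml:

# Hedged sketch: a YAML like the ones above becomes TaskConfig kwargs.
import yaml

raw = """
names:
  - arc_easy_yaml
dataset_path: ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
"""

config = yaml.load(raw, yaml.Loader)
print(config["names"][0])      # -> "arc_easy_yaml": used as the registry key
print(config["dataset_path"])  # the rest feeds TaskConfig(**config) on the generated class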
lm_eval/utils.py

@@ -240,6 +240,19 @@ def run_task_tests(task_list: List[str]):
         )


+def get_git_commit_hash():
+    """
+    Gets the git commit hash of your current repo (if it exists).
+    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
+    """
+    try:
+        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
+        git_hash = git_hash.decode()
+    except subprocess.CalledProcessError:
+        git_hash = None
+    return git_hash
+
+
 env = Environment(loader=BaseLoader, undefined=StrictUndefined)
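get_git_commit_hash shells out to `git describe --always` and degrades to None when the command fails. A hedged sketch of the same pattern as a standalone, runnable snippet; commit_hash_or_none is a hypothetical helper name, not code from this diff:

# Hedged sketch: the same git-describe pattern, self-contained.
import subprocess
from typing import Optional

def commit_hash_or_none() -> Optional[str]:
    try:
        return subprocess.check_output(["git", "describe", "--always"]).strip().decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: not inside a git repo; FileNotFoundError: git not installed
        return None

print(commit_hash_or_none())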
main.py

@@ -3,15 +3,16 @@ import json
 import logging
 import fnmatch
 import yaml
 import os

-from lm_eval import tasks, evaluator
+# import lm_eval.api.task
+from lm_eval import evaluator, tasks
+from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY

 logging.getLogger("openai").setLevel(logging.WARNING)
 os.environ['TOKENIZERS_PARALLELISM'] = 'false'

+ALL_TASKS = sorted(list(TASK_REGISTRY))


 class MultiChoice:
     def __init__(self, choices):
         self.choices = choices
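main.py now builds ALL_TASKS directly from TASK_REGISTRY, so YAML-registered aliases are selectable from the CLI like any other task. A hedged sketch of the kind of wildcard expansion the fnmatch import suggests; pattern_match and the hard-coded task list are illustrative, not code from this diff:

# Hedged sketch: expanding comma-separated --tasks patterns against ALL_TASKS.
import fnmatch

ALL_TASKS = ["arc_challenge_yaml", "arc_easy_yaml", "gsm8k_yaml", "lambada_openai_yaml"]

def pattern_match(patterns, source_list):
    # hypothetical helper: resolve each pattern (possibly a glob) to known task names
    names = set()
    for pattern in patterns:
        names.update(fnmatch.filter(source_list, pattern))
    return sorted(names)

print(pattern_match("arc_*_yaml,gsm8k_yaml".split(","), ALL_TASKS))
# -> ['arc_challenge_yaml', 'arc_easy_yaml', 'gsm8k_yaml']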