gaoqiong / lm-evaluation-harness / Commits / f275301a

Commit f275301a, authored Apr 23, 2023 by haileyschoelkopf, committed by Hailey Schoelkopf on Apr 24, 2023.

    make tasks and models registered by decorators

Parent: e7c18e53

Showing 11 changed files with 432 additions and 317 deletions (+432 / -317).
Files changed:

  lm_eval/api/model.py          +23    -0
  lm_eval/api/task.py           +82    -1
  lm_eval/evaluator.py           +2    -2
  lm_eval/models/__init__.py    +11    -8
  lm_eval/models/gpt2.py         +3    -1
  lm_eval/tasks/__init__.py    +294  -294
  lm_eval/tasks/arc.py           +3    -2
  lm_eval/tasks/gsm8k.py         +2    -1
  lm_eval/tasks/lambada.py       +3    -2
  lm_eval/tasks/wikitext.py      +2    -2
  main.py                        +7    -4
lm_eval/api/model.py    View file @ f275301a

@@ -2,6 +2,29 @@ import abc
 from lm_eval import utils
+
+MODEL_REGISTRY = {}
+
+
+def register_model(name):
+    # TODO: should fairseq/elk be cited for this design pattern?
+    def decorate(cls):
+        assert issubclass(
+            cls, LM
+        ), f"Model '{name}' ({cls.__name__}) must extend LM class"
+
+        assert (
+            name not in MODEL_REGISTRY
+        ), f"Model named '{name}' conflicts with existing model!"
+
+        MODEL_REGISTRY[name] = cls
+        return cls
+
+    return decorate
+
+
+def get_model(model_name):
+    return MODEL_REGISTRY[model_name]
+
+
 class LM(abc.ABC):
     def __init__(self):
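For readers unfamiliar with the pattern, here is a minimal, self-contained toy version of the decorator-based registry this hunk introduces. Every name in it (REGISTRY, register, Base, MyModel) is illustrative only and not part of the commit; it just mirrors the type check, name-collision check, and dict insertion shown above.

# Toy sketch of the registry pattern; names are hypothetical.
REGISTRY = {}


class Base:
    """Stands in for the LM base class."""


def register(name):
    def decorate(cls):
        # Mirror the two checks in register_model: subclass check and name collision.
        assert issubclass(cls, Base), f"'{name}' ({cls.__name__}) must extend Base"
        assert name not in REGISTRY, f"'{name}' conflicts with an existing entry!"
        REGISTRY[name] = cls
        return cls

    return decorate


@register("toy")
class MyModel(Base):
    """Registered under 'toy' the moment this class statement executes."""


print(REGISTRY)  # {'toy': <class '__main__.MyModel'>}

Because the registration happens when the class statement runs, simply importing a module that defines a decorated class is enough to populate the registry.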
lm_eval/api/task.py    View file @ f275301a

@@ -9,6 +9,8 @@ import itertools
 import datasets
 import numpy as np
+from typing import List, Union
+
 from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte

@@ -31,7 +33,7 @@ class TaskConfig(dict):
     # TODO: add this as more jinja2 appended to start of jinja2 templates. Should allow users to set vars
     # s.t. they can define e.g. {% set question = query %} to map dataset columns to "canonical" names in prompts.
-    template_vars: str = None
+    template_aliases: str = None
     doc_to_text: str = None
     doc_to_target: str = None

@@ -609,3 +611,82 @@ class PerplexityTask(Task, abc.ABC):
     def count_words(cls, doc):
         """Downstream tasks with custom word boundaries should override this!"""
         return len(re.split(r"\s+", doc))
+
+
+# TODO: confirm we want this to go in this file
+TASK_REGISTRY = {}
+ALL_TASKS = []
+
+
+def register_task(name):
+    def decorate(cls):
+        assert issubclass(
+            cls, Task
+        ), f"Task '{name}' ({cls.__name__}) must extend Task class"
+
+        assert (
+            name not in TASK_REGISTRY
+        ), f"Task named '{name}' conflicts with existing task!"
+
+        TASK_REGISTRY[name] = cls
+        ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import right.
+        return cls
+
+    return decorate
+
+
+##### Task registry utils and setup.
+# ALL_TASKS = sorted(list(TASK_REGISTRY))
+
+
+def get_task(task_name):
+    try:
+        return TASK_REGISTRY[task_name]
+    except KeyError:
+        print("Available tasks:")
+        pprint(TASK_REGISTRY)
+        raise KeyError(f"Missing task {task_name}")
+
+
+def get_task_name_from_object(task_object):
+    for name, class_ in TASK_REGISTRY.items():
+        if class_ is task_object:
+            return name
+
+    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
+    return (
+        task_object.EVAL_HARNESS_NAME
+        if hasattr(task_object, "EVAL_HARNESS_NAME")
+        else type(task_object).__name__
+    )
+
+
+def get_task_name_from_config(task_config):
+    return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
+
+
+def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None):
+    # TODO: pass num_fewshot and other cmdline overrides in a better way
+    task_name_dict = {
+        task_name: get_task(task_name)(
+            config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name}
+        )
+        for task_name in task_name_list
+        if isinstance(task_name, str)
+    }
+    task_name_from_config_dict = {
+        get_task_name_from_config(task_config): ConfigurableTask(config=task_config)
+        for task_config in task_name_list
+        if isinstance(task_config, dict)
+    }
+    task_name_from_object_dict = {
+        get_task_name_from_object(task_object): task_object
+        for task_object in task_name_list
+        if isinstance(task_object, Task)
+    }
+    assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
+    return {
+        **task_name_dict,
+        **task_name_from_config_dict,
+        **task_name_from_object_dict,
+    }
\ No newline at end of file
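The new get_task_dict accepts a mixed list of task names, config dicts, and already-constructed Task objects, and dispatches on type to build one name-to-task mapping. Below is a hypothetical, simplified sketch of that three-way dispatch, runnable on its own; Task, MyTask, TASK_REGISTRY, and build_task_dict here are stand-ins, and the config dict is kept as-is rather than wrapped in a ConfigurableTask as the real code does.

# Simplified sketch of the isinstance-based dispatch in get_task_dict.
from typing import List, Union


class Task:
    pass


class MyTask(Task):
    pass


TASK_REGISTRY = {"my_task": MyTask}


def build_task_dict(specs: List[Union[str, dict, Task]]) -> dict:
    out = {}
    for spec in specs:
        if isinstance(spec, str):       # registered task name -> instantiate from registry
            out[spec] = TASK_REGISTRY[spec]()
        elif isinstance(spec, dict):    # ad-hoc config -> derived "configurable_*" name
            out["configurable_{dataset_path}_{dataset_name}".format(**spec)] = spec
        elif isinstance(spec, Task):    # pre-built task object -> keyed by its class name
            out[type(spec).__name__] = spec
    return out


print(list(build_task_dict(["my_task", {"dataset_path": "d", "dataset_name": "n"}, MyTask()])))
# ['my_task', 'configurable_d_n', 'MyTask']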
lm_eval/evaluator.py    View file @ f275301a

@@ -58,14 +58,14 @@ def simple_evaluate(
     if isinstance(model, str):
         if model_args is None:
             model_args = ""
-        lm = lm_eval.models.get_model(model).create_from_arg_string(
+        lm = lm_eval.api.model.get_model(model).create_from_arg_string(
             model_args, {"batch_size": batch_size, "device": device}
         )
     else:
         assert isinstance(model, lm_eval.api.model.LM)
         lm = model

-    task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
+    task_dict = lm_eval.api.task.get_task_dict(tasks, num_fewshot=num_fewshot)

     if check_integrity:
         run_task_tests(task_list=tasks)
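After this change a model string resolves through the registry in lm_eval.api.model rather than through lm_eval.models. A hedged usage sketch of that lookup path, assuming the package layout at this commit (the "pretrained=gpt2" arg string is only illustrative):

# Sketch only: importing lm_eval.models runs the @register_model decorators in
# its submodules (in this commit only gpt2.HFLM is decorated), which populate
# MODEL_REGISTRY defined in lm_eval.api.model.
import lm_eval.api.model
import lm_eval.models  # noqa: F401  (imported for its registration side effects)

lm_cls = lm_eval.api.model.get_model("hf-causal")  # -> lm_eval.models.gpt2.HFLM
lm = lm_cls.create_from_arg_string("pretrained=gpt2", {"batch_size": 1, "device": "cpu"})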
lm_eval/models/__init__.py    View file @ f275301a

+from lm_eval.api.model import LM, MODEL_REGISTRY
 from . import gpt2
 from . import gpt3
 from . import textsynth
 from . import dummy

-MODEL_REGISTRY = {
-    "hf-causal": gpt2.HFLM,
-    "openai": gpt3.GPT3LM,
-    "textsynth": textsynth.TextSynthLM,
-    "dummy": dummy.DummyLM,
-}
+# MODEL_REGISTRY = {}
+# MODEL_REGISTRY = {
+#     "hf-causal": gpt2.HFLM,
+#     "openai": gpt3.GPT3LM,
+#     "textsynth": textsynth.TextSynthLM,
+#     "dummy": dummy.DummyLM,
+# }

-def get_model(model_name):
-    return MODEL_REGISTRY[model_name]
+# def get_model(model_name):
+#     return MODEL_REGISTRY[model_name]
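With the registry now owned by lm_eval.api.model, adding a backend no longer means editing a dict in this file. A hedged sketch of what a new backend module could look like under this scheme; the module name, class name, and "my-backend" key are all made up, and the interface methods named in the comment are only a guess at what LM requires:

# lm_eval/models/my_backend.py -- hypothetical new backend, not part of this commit
from lm_eval.api.model import LM, register_model


@register_model("my-backend")
class MyBackendLM(LM):
    # implement the LM interface here (e.g. loglikelihood, greedy_until, ...)
    ...

A matching `from . import my_backend` line in this __init__ would make the registration run at import time, mirroring the existing `from . import gpt2` imports.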
lm_eval/models/gpt2.py    View file @ f275301a

@@ -6,9 +6,11 @@ from tqdm import tqdm
 import torch.nn.functional as F

 from lm_eval import utils
-from lm_eval.api.model import LM
+from lm_eval.api.model import LM, register_model
+# from lm_eval.models import register_model


+@register_model("hf-causal")
 class HFLM(LM):
     def __init__(
         self,
lm_eval/tasks/__init__.py    View file @ f275301a

(This diff is collapsed in the page view and not shown: +294 / -294.)
lm_eval/tasks/arc.py    View file @ f275301a

@@ -12,7 +12,7 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
 Homepage: https://allenai.org/data/arc
 """
-from lm_eval.api.task import MultipleChoiceTask
+from lm_eval.api.task import MultipleChoiceTask, register_task
 from lm_eval.prompts import get_prompt
 from lm_eval import utils

@@ -28,7 +28,7 @@ _CITATION = """
 }
 """

+@register_task("arc_easy")
 class ARCEasy(MultipleChoiceTask):
     VERSION = "2.0"
     DATASET_PATH = "ai2_arc"

@@ -80,6 +80,7 @@ class ARCEasy(MultipleChoiceTask):
         return doc["query"]

+@register_task("arc_challenge")
 class ARCChallenge(ARCEasy):
     DATASET_PATH = "ai2_arc"
     DATASET_NAME = "ARC-Challenge"
lm_eval/tasks/gsm8k.py    View file @ f275301a

@@ -17,7 +17,7 @@ model's sample/generation function.
 Homepage: https://github.com/openai/grade-school-math
 """
 import re

-from lm_eval.api.task import Task
+from lm_eval.api.task import Task, register_task
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean

@@ -41,6 +41,7 @@ ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
 INVALID_ANS = "[invalid]"

+@register_task("gsm8k")
 class GradeSchoolMath8K(Task):
     VERSION = 0
     DATASET_PATH = "gsm8k"
lm_eval/tasks/lambada.py    View file @ f275301a

@@ -12,7 +12,7 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
-from lm_eval.api.task import Task
+from lm_eval.api.task import Task, register_task
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, perplexity

@@ -75,6 +75,7 @@ class LambadaBase(Task):
         return {"ppl": False, "acc": True}

+@register_task("lambada_standard")
 class LambadaStandard(LambadaBase):
     """The LAMBADA task using the standard original LAMBADA dataset."""

@@ -90,7 +91,7 @@ class LambadaStandard(LambadaBase):
     def has_test_docs(self):
         return True

+@register_task("lambada_openai")
 class LambadaOpenAI(LambadaBase):
     """The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
     original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model.
lm_eval/tasks/wikitext.py    View file @ f275301a

@@ -10,7 +10,7 @@ NOTE: This `Task` is based on WikiText-2.
 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
 import re

-from lm_eval.api.task import PerplexityTask
+from lm_eval.api.task import PerplexityTask, register_task


 _CITATION = """

@@ -58,7 +58,7 @@ def wikitext_detokenizer(string):
     return string

+@register_task("wikitext")
 class WikiText(PerplexityTask):
     VERSION = "2.0"
     DATASET_PATH = "EleutherAI/wikitext_document_level"
main.py    View file @ f275301a

@@ -5,14 +5,17 @@ import fnmatch
 import yaml

 from lm_eval import tasks, evaluator
-from lm_eval.api.task import ConfigurableTask
+# import lm_eval.api.task
+from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY

 logging.getLogger("openai").setLevel(logging.WARNING)

+ALL_TASKS = sorted(list(TASK_REGISTRY))

 class MultiChoice:
     def __init__(self, choices):
         self.choices = choices
+        print(f"{ALL_TASKS} is this")

     # Simple wildcard support (linux filename patterns)
     def __contains__(self, values):

@@ -31,7 +34,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
     parser.add_argument("--model_args", default="")
-    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
+    parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS))
     parser.add_argument("--config", default=None)
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)

@@ -80,9 +83,9 @@ def main():
             task_names.append(config)
         else:
-            task_names = tasks.ALL_TASKS
+            task_names = ALL_TASKS
     else:
-        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
+        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)

     print(f"Selected Tasks: {task_names}")
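The MultiChoice/pattern_match pair gives shell-style wildcard selection over ALL_TASKS ("linux filename patterns", per the comment in the hunk above). The body of pattern_match is not part of this diff, so the sketch below is only a plausible fnmatch-based stand-in with example task names taken from the tasks registered in this commit:

# Hypothetical sketch of wildcard task selection via fnmatch.
import fnmatch

ALL_TASKS = ["arc_easy", "arc_challenge", "lambada_standard", "lambada_openai", "wikitext"]


def pattern_match(patterns, source_list):
    # Collect every task name matching any of the shell-style patterns.
    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(task_names)


print(pattern_match("arc_*,wikitext".split(","), ALL_TASKS))
# ['arc_challenge', 'arc_easy', 'wikitext']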