Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
c5ed8cdc
Unverified
Commit
c5ed8cdc
authored
May 20, 2023
by
Lintang Sutawika
Committed by
GitHub
May 20, 2023
Browse files
Merge pull request #501 from EleutherAI/update-config
Update config
parents
f6b76f5d
c17e3659
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
127 additions
and
46 deletions
+127
-46
lm_eval/tasks/super_glue/wsc/preprocess_wsc.py
lm_eval/tasks/super_glue/wsc/preprocess_wsc.py
+17
-0
lm_eval/tasks/super_glue/wsc/t5-prompt.yaml
lm_eval/tasks/super_glue/wsc/t5-prompt.yaml
+16
-0
lm_eval/tasks/vanilla/__init__.py
lm_eval/tasks/vanilla/__init__.py
+0
-8
lm_eval/tasks/wikitext.py
lm_eval/tasks/wikitext.py
+4
-1
lm_eval/utils.py
lm_eval/utils.py
+63
-5
main.py
main.py
+26
-31
setup.py
setup.py
+1
-1
No files found.
lm_eval/tasks/super_glue/wsc/preprocess_wsc.py
0 → 100644
View file @
c5ed8cdc
import re


def doc_to_text(x):
    """Render a SuperGLUE WSC example in the T5 span-marking prompt format.

    Wraps the first referent span in ``*`` markers and the second (the
    pronoun) in ``#`` markers, e.g.
    ``"The * city councilmen * refused ... because # they # feared violence."``

    Args:
        x: A dict with keys ``text``, ``span1_text``, ``span1_index``,
           ``span2_text`` and ``span2_index`` (word indices into ``text``).

    Returns:
        The marked-up text string.
    """

    def _mark_span(text, span_str, span_idx, mark):
        # Skip exactly `span_idx` whitespace-delimited words, then match the
        # span itself. BUG FIX: the span text must be escaped before being
        # spliced into the pattern — the original inserted it verbatim via
        # re.sub, so metacharacters ('.', '(', '\') corrupted the regex or
        # raised "bad escape" errors.
        pattern = r"^((?:\S+\s){%s})(%s)" % (span_idx, re.escape(span_str))
        return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)

    text = x["text"]
    text = _mark_span(text, x["span1_text"], x["span1_index"], "*")
    # Compensate for 2 added "words" added in previous step.
    span2_index = x["span2_index"] + 2 * (x["span1_index"] < x["span2_index"])
    text = _mark_span(text, x["span2_text"], span2_index, "#")
    return text
lm_eval/tasks/super_glue/wsc/t5-prompt.yaml
0 → 100644
View file @
c5ed8cdc
# SuperGLUE WSC task configuration in the T5 prompt format
# (span-marking prompts from Raffel et al. 2019).
group:
  - super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et. al. 2019"
dataset_path: super_glue        # HuggingFace datasets path
dataset_name: wsc               # Winograd Schema Challenge subset
training_split: train
validation_split: validation
# Resolved by the custom `!function` YAML constructor to the
# doc_to_text function in preprocess_wsc.py next to this file.
doc_to_text: !function "preprocess_wsc.doc_to_text"
# Jinja template: label 0 -> "False", label 1 -> "True".
doc_to_target: "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
lm_eval/tasks/vanilla/__init__.py
deleted
100644 → 0
View file @
f6b76f5d
from
.
import
arc
from
.
import
gsm8k
from
.
import
lambada
from
.
import
pile
from
.
import
wikitext
# TODO: define via __all__
\ No newline at end of file
lm_eval/tasks/
vanilla/
wikitext.py
→
lm_eval/tasks/wikitext.py
View file @
c5ed8cdc
...
...
@@ -10,8 +10,10 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import
re
from
lm_eval.api.task
import
PerplexityTask
,
register_task
from
lm_eval.api.task
import
PerplexityTask
from
lm_eval.api.register
import
register_task
,
register_group
_CITATION
=
"""
@misc{merity2016pointer,
...
...
@@ -58,6 +60,7 @@ def wikitext_detokenizer(string):
return
string
@
register_task
(
"wikitext"
)
class
WikiText
(
PerplexityTask
):
VERSION
=
"2.0"
...
...
lm_eval/utils.py
View file @
c5ed8cdc
import
os
import
pathlib
import
re
import
collections
import
functools
import
inspect
import
sys
import
yaml
import
inspect
import
pathlib
import
functools
import
subprocess
import
collections
import
importlib.util
from
typing
import
List
from
omegaconf
import
OmegaConf
...
...
@@ -146,7 +150,6 @@ class Reorderer:
return
res
def
make_table
(
result_dict
):
"""Generate table of results."""
from
pytablewriter
import
MarkdownTableWriter
,
LatexTableWriter
...
...
@@ -253,6 +256,61 @@ def get_git_commit_hash():
return
git_hash
def import_function(loader, node):
    """PyYAML constructor for the ``!function`` tag.

    The scalar value must look like ``"module.function"``. The module is
    loaded from a ``module.py`` file located in the same directory as the
    YAML file currently being parsed (``loader.name``), and the named
    attribute of that module is returned.
    """
    qualified_name = loader.construct_scalar(node)
    # Path of the YAML file being parsed; the target module must sit
    # alongside it.
    config_dir = os.path.dirname(loader.name)

    module_name, attr_name = qualified_name.split(".")
    module_file = os.path.join(config_dir, "{}.py".format(module_name))

    # Load the module directly from its file path without touching sys.path.
    spec = importlib.util.spec_from_file_location(module_name, module_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return getattr(module, attr_name)
# Add the import_function constructor to the YAML loader, so that task
# configs can reference Python callables via `!function "module.func"`
# (e.g. the doc_to_text entry in task YAML files).
yaml.add_constructor("!function", import_function)
def load_yaml_config(yaml_path):
    """Load a YAML task config, resolving its optional ``include`` key.

    ``include`` may be a single path or a list of paths, each loaded
    recursively and merged. A relative include path is resolved against
    the directory of ``yaml_path``. Precedence: keys in the including
    file win over included ones, and earlier entries in the ``include``
    list win over later ones.

    Args:
        yaml_path: Path to the YAML file to load.

    Returns:
        The (merged) configuration dict.

    Raises:
        OSError / yaml.YAMLError: if this file or any included file
        cannot be read or parsed.
    """
    # Read and parse first; the merging below does not need the file open.
    with open(yaml_path, "rb") as file:
        yaml_config = yaml.full_load(file)

    yaml_dir = os.path.dirname(yaml_path)

    if "include" not in yaml_config:
        return yaml_config

    include_path = yaml_config.pop("include")
    # Accept a bare string as a one-element list (isinstance replaces the
    # non-idiomatic `type(...) == str` check).
    if isinstance(include_path, str):
        include_path = [include_path]

    # Merge from the last include first so that earlier entries override
    # later ones.
    final_yaml_config = {}
    for path in reversed(include_path):
        # Assumes that path is a full path.
        # If not found, assume the included yaml is in the same dir as the
        # original yaml.
        if not os.path.isfile(path):
            path = os.path.join(yaml_dir, path)
        # NOTE: the original wrapped this in a try/except that immediately
        # re-raised (a no-op contradicting its own "ignore" comment);
        # failures now simply propagate.
        final_yaml_config.update(load_yaml_config(path))

    # The including file's own keys take highest precedence.
    final_yaml_config.update(yaml_config)
    return final_yaml_config
# Module-level Jinja2 environment used to render prompt templates.
# StrictUndefined makes a template that references a missing variable
# raise an error instead of silently rendering it as empty text.
env = Environment(loader=BaseLoader, undefined=StrictUndefined)
...
...
main.py
View file @
c5ed8cdc
import
argparse
import
os
import
json
import
logging
import
fnmatch
import
yaml
import
os
import
argparse
from
lm_eval
import
evaluator
,
tasks
from
lm_eval.api.task
import
ConfigurableTask
,
TASK_REGISTRY
from
lm_eval
import
evaluator
,
utils
from
lm_eval.tasks
import
ALL_TASKS
from
lm_eval.logger
import
eval_logger
logging
.
getLogger
(
"openai"
).
setLevel
(
logging
.
WARNING
)
os
.
environ
[
'TOKENIZERS_PARALLELISM'
]
=
'false'
ALL_TASKS
=
sorted
(
list
(
TASK_REGISTRY
))
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
class
MultiChoice
:
def
__init__
(
self
,
choices
):
self
.
choices
=
choices
print
(
f
"
{
ALL_TASKS
}
is this"
)
# Simple wildcard support (linux filename patterns)
def
__contains__
(
self
,
values
):
for
value
in
values
.
split
(
","
):
if
len
(
fnmatch
.
filter
(
self
.
choices
,
value
))
==
0
:
return
False
eval_logger
.
warning
(
"{} is not in task list."
.
format
(
value
))
# eval_logger.info(f"{choices} is this")
return
True
...
...
@@ -47,7 +44,6 @@ def parse_args():
parser
.
add_argument
(
"--decontamination_ngrams_path"
,
default
=
None
)
parser
.
add_argument
(
"--description_dict_path"
,
default
=
None
)
parser
.
add_argument
(
"--check_integrity"
,
action
=
"store_true"
)
return
parser
.
parse_args
()
...
...
@@ -65,30 +61,29 @@ def main():
args
=
parse_args
()
if
args
.
limit
:
print
(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
eval_logger
.
warning
(
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if
args
.
tasks
is
None
:
if
args
.
config
:
task_names
=
[]
for
config_files
in
args
.
config
.
split
(
","
):
with
open
(
config_files
,
"r"
)
as
f
:
config
=
yaml
.
load
(
f
,
yaml
.
Loader
)
if
args
.
num_fewshot
!=
0
:
config
[
"num_fewshot"
]
=
args
.
num_fewshot
if
args
.
batch_size
!=
None
:
config
[
"batch_size"
]
=
args
.
batch_size
if
args
.
tasks
is
not
None
:
if
os
.
path
.
isdir
(
args
.
tasks
):
import
glob
task_names
=
[]
yaml_path
=
os
.
path
.
join
(
args
.
tasks
,
"*.yaml"
)
for
yaml_file
in
glob
.
glob
(
yaml_path
):
config
=
utils
.
load_yaml_config
(
yaml_file
)
task_names
.
append
(
config
)
else
:
task_names
=
ALL_TASKS
else
:
task_names
=
pattern_match
(
args
.
tasks
.
split
(
","
),
ALL_TASKS
)
print
(
f
"Selected Tasks:
{
task_names
}
"
)
tasks_list
=
args
.
tasks
.
split
(
","
)
task_names
=
pattern_match
(
tasks_list
,
ALL_TASKS
)
for
task
in
[
task
for
task
in
tasks_list
if
task
not
in
task_names
]:
if
os
.
path
.
isfile
(
task
):
config
=
utils
.
load_yaml_config
(
task
)
task_names
.
append
(
config
)
eval_logger
.
info
(
f
"Selected Tasks:
{
task_names
}
"
)
results
=
evaluator
.
simple_evaluate
(
model
=
args
.
model
,
...
...
setup.py
View file @
c5ed8cdc
...
...
@@ -42,6 +42,6 @@ setuptools.setup(
extras_require
=
{
"dev"
:
[
"black"
,
"flake8"
,
"pre-commit"
,
"pytest"
,
"pytest-cov"
],
"multilingual"
:
[
"nagisa>=0.2.7"
,
"jieba>=0.42.1"
],
"sentencepiece"
:
[
"sentencepiece>=0.1.98"
,
"protobuf>=4.22.1"
]
"sentencepiece"
:
[
"sentencepiece>=0.1.98"
,
"protobuf>=4.22.1"
]
,
},
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment