Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
7a058d69
Commit
7a058d69
authored
Sep 30, 2021
by
Leo Gao
Browse files
Improve interface and upgrade results dict format
parent
b569f483
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
42 deletions
+73
-42
lm_eval/evaluator.py
lm_eval/evaluator.py
+68
-2
main.py
main.py
+5
-40
No files found.
lm_eval/evaluator.py
View file @
7a058d69
...
@@ -2,9 +2,43 @@ import collections
...
@@ -2,9 +2,43 @@ import collections
import
itertools
import
itertools
import
random
import
random
import
lm_eval.metrics
import
lm_eval.metrics
import
lm_eval.models
import
lm_eval.tasks
import
lm_eval.base
import
numpy
as
np
def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None,
                    device=None, no_cache=False, limit=None, bootstrap_iters=100000):
    """Instantiate a model from a name/arg-string and evaluate it on `task_names`.

    This is the high-level convenience entry point: it builds the LM, optionally
    wraps it in a disk cache, resolves the task dict, runs `evaluate`, and
    attaches a "config" section recording every knob used for the run.

    :param model: str
        Registered model name understood by `lm_eval.models.get_model`.
    :param model_args: str
        Comma-separated `key=value` argument string for the model constructor.
    :param task_names: list of str
        Names of tasks to evaluate on.
    :param num_fewshot: int
        Number of few-shot examples to place in context.
    :param batch_size: int or None
        Batch size for model calls (model-specific default when None).
    :param device: str or None
        Device to run the model on (model-specific default when None).
    :param no_cache: bool
        If True, skip the on-disk request cache.
    :param limit: int or None
        If set, evaluate only the first `limit` documents per task (testing only).
    :param bootstrap_iters: int
        Number of bootstrap resamples used for stderr estimation.
    :return: dict with "results", "versions", and "config" keys.
    """
    # Fixed seeds so repeated runs of the same config are reproducible.
    random.seed(1234)
    np.random.seed(1234)

    lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
        'batch_size': batch_size, 'device': device,
    })

    if not no_cache:
        # Cache file name is derived from the model + args; characters that are
        # unsafe in file names ('=', ',', '/') are replaced.
        lm = lm_eval.base.CachingLM(
            lm,
            'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db',
        )

    task_dict = lm_eval.tasks.get_task_dict(task_names)

    # BUG FIX: `bootstrap_iters` was accepted (and recorded in the config below)
    # but never forwarded to `evaluate`, so it silently had no effect.
    results = evaluate(lm, task_dict, False, num_fewshot, limit, bootstrap_iters)

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
    }

    return results
def
evaluate
(
lm
,
task_dict
,
provide_description
,
num_fewshot
,
limit
,
bootstrap_iters
=
100000
):
def
evaluate
(
lm
,
task_dict
,
provide_description
,
num_fewshot
,
limit
,
bootstrap_iters
=
100000
):
assert
not
provide_description
# not implemented. todo: implement proper description-providing system
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
task_dict_items
=
[(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
(
task
.
has_validation_docs
()
or
task
.
has_test_docs
())]
task_dict_items
=
[(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
(
task
.
has_validation_docs
()
or
task
.
has_test_docs
())]
...
@@ -100,6 +134,38 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
...
@@ -100,6 +134,38 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
results
[
task_name
][
metric
+
"_stderr"
]
=
stderr
(
items
)
results
[
task_name
][
metric
+
"_stderr"
]
=
stderr
(
items
)
return
{
return
{
"results"
:
results
,
"results"
:
dict
(
results
)
,
"versions"
:
versions
"versions"
:
dict
(
versions
)
}
}
def make_table(result_dict):
    """Render a results dict (as produced by `evaluate`) as a markdown table.

    Expects `result_dict["results"]` mapping task name -> {metric: value,
    metric + "_stderr": value} and `result_dict["versions"]` mapping task
    name -> version. Returns the markdown string; a LaTeX writer is built
    alongside for future use but not emitted yet.
    """
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    md_writer.headers = headers
    latex_writer.headers = headers

    values = []
    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            # Stderr entries are rendered next to their metric, not as rows.
            if m.endswith("_stderr"):
                continue
            stderr_key = m + "_stderr"
            if stderr_key in dic:
                se = dic[stderr_key]
                row = [k, version, m, '%.4f' % v, '±', '%.4f' % se]
            else:
                row = [k, version, m, '%.4f' % v, '', '']
            values.append(row)
            # Blank out the task/version columns after the first metric row
            # so each task name appears only once in the table.
            k = ""
            version = ""

    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
\ No newline at end of file
main.py
View file @
7a058d69
...
@@ -17,7 +17,6 @@ def parse_args():
...
@@ -17,7 +17,6 @@ def parse_args():
parser
.
add_argument
(
'--num_fewshot'
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
'--num_fewshot'
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
1234
)
parser
.
add_argument
(
'--output_path'
,
default
=
None
)
parser
.
add_argument
(
'--output_path'
,
default
=
None
)
parser
.
add_argument
(
'--limit'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--limit'
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
'--no_cache'
,
action
=
"store_true"
)
parser
.
add_argument
(
'--no_cache'
,
action
=
"store_true"
)
...
@@ -26,63 +25,29 @@ def parse_args():
...
@@ -26,63 +25,29 @@ def parse_args():
def main():
    """CLI entry point: parse args, run the evaluation, report results."""
    args = parse_args()

    assert not args.provide_description  # not implemented

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    # Resolve the task list: either every registered task, or a comma list.
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")

    # All model construction, caching, and seeding is delegated to
    # evaluator.simple_evaluate.
    results = evaluator.simple_evaluate(
        args.model,
        args.model_args,
        task_names,
        args.num_fewshot,
        args.batch_size,
        args.device,
        args.no_cache,
        args.limit,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        with open(args.output_path, "w") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}"
        f", num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment