Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
879aabd6
Commit
879aabd6
authored
Feb 27, 2021
by
Leo Gao
Browse files
Implement fewshot description experiment
parent
2c41ecf6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
83 additions
and
0 deletions
+83
-0
scripts/fewshot_description_experiment.py
scripts/fewshot_description_experiment.py
+83
-0
No files found.
scripts/fewshot_description_experiment.py
0 → 100644
View file @
879aabd6
import
argparse
import
json
import
numpy
as
np
import
random
import
itertools
import
collections
import
logging
from
lm_eval
import
models
,
tasks
,
evaluator
,
base
logging
.
getLogger
(
"openai"
).
setLevel
(
logging
.
WARNING
)
fewshot_descriptions
=
[
"foo"
,
"bar"
]
task
=
"lambada"
num_fewshot
=
0
model
=
"gpt2"
model_args
=
""
limit
=
None
no_cache
=
False
class
CustomDescTask
:
def
__init__
(
self
,
task
,
desc
):
self
.
task
=
task
self
.
desc
=
desc
def
fewshot_description
():
return
self
.
desc
self
.
task
.
fewshot_description
=
fewshot_description
def
__getattr__
(
self
,
attr
):
return
getattr
(
self
.
task
,
attr
)
def
main
():
random
.
seed
(
42
)
np
.
random
.
seed
(
42
)
lm
=
models
.
get_model
(
model
).
create_from_arg_string
(
model_args
)
if
limit
:
print
(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if
not
no_cache
:
lm
=
base
.
CachingLM
(
lm
,
'lm_cache/'
+
model
+
'_'
+
model_args
.
replace
(
'='
,
'-'
).
replace
(
','
,
'_'
)
+
'.db'
)
task_dict
=
tasks
.
get_task_dict
([
task
])
for
desc
in
fewshot_descriptions
:
custom_task_dict
=
{
k
:
CustomDescTask
(
v
,
desc
)
for
k
,
v
in
task_dict
.
items
()}
results
=
evaluator
.
evaluate
(
lm
,
custom_task_dict
,
True
,
num_fewshot
,
limit
)
dumped
=
json
.
dumps
(
results
,
indent
=
2
)
print
(
'Description:'
,
desc
)
print
(
dumped
)
# MAKE TABLE
from
pytablewriter
import
MarkdownTableWriter
writer
=
MarkdownTableWriter
()
writer
.
headers
=
[
"Task"
,
"Metric"
,
"Value"
]
values
=
[]
for
k
,
dic
in
results
.
items
():
for
m
,
v
in
dic
.
items
():
values
.
append
([
k
,
m
,
'%.4f'
%
v
])
k
=
""
writer
.
value_matrix
=
values
print
(
writer
.
dumps
())
if
__name__
==
"__main__"
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment