gaoqiong / lm-evaluation-harness

Commit b4ad893c, authored Apr 25, 2022 by ken

    Merge master

Parents: 8c83a821, 20820c3c

Changes: 35 files in total; this page shows 15 changed files with 283 additions and 165 deletions (+283 / -165).
lm_eval/datasets/wikitext/__init__.py   +0   -0
lm_eval/evaluator.py                    +50  -17
lm_eval/models/gpt2.py                  +4   -0
lm_eval/models/t0.py                    +69  -48
lm_eval/models/t5.py                    +79  -52
lm_eval/tasks/__init__.py               +2   -0
lm_eval/tasks/coqa.py                   +6   -13
lm_eval/tasks/gem_webnlg.py             +37  -0
lm_eval/tasks/hendrycks_ethics.py       +4   -6
lm_eval/tasks/hendrycks_math.py         +3   -3
lm_eval/tasks/wikitext.py               +4   -4
lm_eval/tasks/wsc273.py                 +2   -2
main.py                                 +20  -17
scripts/write_out.py                    +2   -2
setup.py                                +1   -1
lm_eval/datasets/wikitext/__init__.py (new file, mode 100644, empty)
lm_eval/evaluator.py

@@ -173,10 +173,6 @@ def evaluate(
     # get lists of each type of request
     for task_prompt_name, task in task_dict_items:
-        # if task.is_generation_task():
-        #     print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
-        #     continue
         versions[task_prompt_name] = task.VERSION
         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
@@ -188,7 +184,7 @@ def evaluate(
             raise RuntimeError("Task has neither test_docs nor validation_docs")
         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
-        task_docs = list(task_doc_func())
+        task_docs = list(enumerate(list(task_doc_func())))
         rnd = random.Random()
         rnd.seed(42)
         rnd.shuffle(task_docs)
@@ -199,14 +195,17 @@ def evaluate(
             else ""
         )
-        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+        for doc_id, (original_doc_id, doc) in enumerate(
+            itertools.islice(task_docs, 0, limit)
+        ):
             if task.invalid_doc_for_prompt(doc):
                 continue
             docs[(task_prompt_name, doc_id)] = doc
-            ctx = task.fewshot_context(
+            ctx, fewshotex_logging_info = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
+            fewshotex_logging_info["doc_id"] = original_doc_id
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
@@ -215,7 +214,7 @@ def evaluate(
                 # i: index in requests for a single task instance
                 # doc_id: unique id that we can get back to a doc using `docs`
                 requests_origin[req.request_type].append(
-                    (i, task_prompt_name, doc, doc_id)
+                    (i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
                 )

     # all responses for each (task, doc)
@@ -234,33 +233,57 @@ def evaluate(
             x if req.index is None else x[req.index]
             for x, req in zip(resps, reqs)
         ]

-        for resp, (i, task_prompt_name, doc, doc_id) in zip(
+        for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
             resps, requests_origin[reqtype]
         ):
-            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
+            process_res_queue[(task_prompt_name, doc_id)].append(
+                (i, resp, fewshotex_logging_info)
+            )

     vals = collections.defaultdict(list)

     # unpack results and sort back in order and return control to Task
-    for (task_prompt_name, doc_id), requests in process_res_queue.items():
-        requests.sort(key=lambda x: x[0])
-        requests = [x[1] for x in requests]
+    examples = []
+    for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
+        per_doc_requests.sort(key=lambda x: x[0])
+        per_doc_results = [x[1] for x in per_doc_requests]
+        fewshot_logging_info = [x[2] for x in per_doc_requests][0]

         task = task_dict[task_prompt_name]
         doc = docs[(task_prompt_name, doc_id)]

-        metrics = task.process_results(doc, requests)
+        output = task.process_results(doc, per_doc_results)
+        if task.save_examples:
+            metrics, example = output
+            example.update(fewshot_logging_info)
+            example.update(task.get_logging_info())
+            examples.append(example)
+        else:
+            metrics = output
+            example = fewshot_logging_info
+            example.update(task.get_logging_info())
+            examples.append(example)
         for metric, value in metrics.items():
             vals[(task_prompt_name, metric)].append(value)

     # aggregate results
+    metric_results = []
     for (task_prompt_name, metric), items in vals.items():
         task_name, prompt_name = task_prompt_name.split("+")
         results[task_prompt_name]["task_name"] = task_name
         results[task_prompt_name]["prompt_name"] = prompt_name
         task = task_dict[task_prompt_name]
         results[task_prompt_name][metric] = task.aggregation()[metric](items)
+        _metric_results = {
+            "task_name": task_name,
+            "prompt_name": prompt_name,
+            metric: task.aggregation()[metric](items),
+            **task.get_logging_info(),
+        }

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         stderr = lm_eval.metrics.stderr_for_metric(
@@ -271,8 +294,18 @@ def evaluate(
         )
         if stderr is not None:
             results[task_prompt_name][metric + "_stderr"] = stderr(items)
+            _metric_results[metric + "_stderr"] = stderr(items)
+        metric_results.append(_metric_results)

-    return {"results": dict(results), "versions": dict(versions)}
+    return {
+        # List of results that tracks the averages per model and prompt.
+        "results": metric_results,
+        "versions": dict(versions),
+        # List of all prompt x doc examples with additional information in it.
+        "examples": examples,
+        # Original results used for generating the table when running this file.
+        "table_results": dict(results),
+    }


 def make_table(result_dict):
@@ -293,7 +326,7 @@ def make_table(result_dict):
     ]
     values = []
-    for k, dic in result_dict["results"].items():
+    for k, dic in result_dict["table_results"].items():
         version = result_dict["versions"][k]
         for m, v in dic.items():
             if m.endswith("_stderr"):
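
The net effect of the evaluator changes is a new return shape: "results" becomes a flat list of per-(task, prompt) metric rows, the old nested mapping survives as "table_results" for make_table, and a new "examples" list records one entry per prompt x doc with the few-shot logging info attached. A minimal sketch of consuming that structure; the key names come from the diff above, while the concrete task, metrics, and values are illustrative only:

    # Sketch: walking the new `evaluate` output. Keys per the diff; values invented.
    output = {
        "results": [
            {"task_name": "coqa", "prompt_name": "qa_turn", "f1": 0.41, "f1_stderr": 0.02},
        ],
        "versions": {"coqa+qa_turn": 0},
        "examples": [{"doc_id": 17, "target": "yes", "pred": "yes"}],
        "table_results": {
            "coqa+qa_turn": {"task_name": "coqa", "prompt_name": "qa_turn", "f1": 0.41},
        },
    }

    for row in output["results"]:   # one flat row per (task, prompt) metric set
        print(row["task_name"], row["prompt_name"])
    for ex in output["examples"]:   # one record per prompt x doc
        print(ex["doc_id"])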
lm_eval/models/gpt2.py

@@ -72,6 +72,10 @@ class HFLM(BaseLM):
         # if gpus > 1:
         #     self.gpt2 = nn.DataParallel(self.gpt2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
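
The addition mirrors the existing eot_token_id property with its string form: for GPT-2 tokenizers the EOS token doubles as the end-of-text token, so the two properties are the string and integer views of the same special token. A small illustration, not from the diff (requires transformers; the values shown are the standard GPT-2 ones):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok.eos_token)     # '<|endoftext|>' -> what the new `eot_token` returns
    print(tok.eos_token_id)  # 50256           -> what `eot_token_id` returns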
lm_eval/models/t0.py

@@ -2,37 +2,36 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM
+from lm_eval.base import BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
 import math


-class T0LM(LM):
-    MAX_GEN_TOKS = 256
-    MAX_INP_LENGTH = 512
-    VOCAB_SIZE = 32100
-    EOT_TOKEN_ID = 1
+class T0LM(BaseLM):
+    # MAX_GEN_TOKS = 256
+    # MAX_INP_LENGTH = 512
+    # VOCAB_SIZE = 32100
+    # EOT_TOKEN_ID = 1

     def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1):
         super().__init__()
         if device:
-            self.device = torch.device(device)
+            self._device = torch.device(device)
         else:
-            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        print(pretrained)
+            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
         self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
         self.t0.eval()
         if parallelize == "True":
-            print(parallelize)
             self.t0.parallelize()
-            self.device = torch.device('cuda:0')
+            self._device = torch.device('cuda:0')
         else:
-            self.t0.to(self.device)
+            self.t0.to(self._device)
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
-        self.max_length = self.MAX_INP_LENGTH
+        # self.max_length = self.MAX_INP_LENGTH
         self.batch_size = int(batch_size)
@@ -42,6 +41,53 @@ class T0LM(LM):
         args2 = {k: v for k, v in additional_config.items() if v is not None}
         return cls(**args, **args2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def max_gen_toks(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def batch_size(self):
+        # TODO: fix multi-gpu
+        return self._batch_size  # * gpus
+
+    @property
+    def device(self):
+        # TODO: fix multi-gpu
+        return self._device
+
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def _model_call(self, inputs_tok, targets_tok):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        with torch.no_grad():
+            return self.t0(**inputs_tok, labels=targets_tok["input_ids"])
+
     def loglikelihood(self, requests):
         res = []
         for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests) / self.batch_size)):
@@ -62,7 +108,7 @@ class T0LM(LM):
             targets_tok = self.tokenizer(
                 list(targets),
-                max_length=self.MAX_GEN_TOKS,
+                max_length=self.max_gen_toks,
                 padding=True,
                 # truncation=True,
                 add_special_tokens=False,
@@ -72,11 +118,7 @@ class T0LM(LM):
             for key in targets_tok:
                 targets_tok[key] = targets_tok[key][:, -(self.max_length - 1):]

-            with torch.no_grad():
-                outputs = self.t0(**inputs_tok, labels=targets_tok["input_ids"])
+            outputs = self._model_call(inputs_tok, targets_tok)

             log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
@@ -103,9 +145,6 @@ class T0LM(LM):
                 res.append(answer)
         return res

-    def loglikelihood_rolling(self, requests):
-        raise NotImplementedError
-
     def _get_stopping_criteria(self, stopping_criteria_ids):
         class MultitokenEOSCriteria(transformers.StoppingCriteria):
@@ -133,29 +172,11 @@ class T0LM(LM):
                 EOSCriteria(self.tokenizer.eos_token)
             ])

-    def greedy_until(self, requests):
-        res = []
-        for context, until in tqdm(requests):
-            if isinstance(until, str):
-                until = [until]
-            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
-            stopping_criteria_ids = self.tokenizer.encode(until[0])
-            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-            cont = self.t0.generate(
-                context_enc,
-                max_length=self.MAX_GEN_TOKS,
-                stopping_criteria=stopping_criteria,
-                do_sample=False
-            )
-            s = self.tokenizer.decode(cont[0].tolist())
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-            res.append(s)
-        return res
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        return self.t0.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
+        )
\ No newline at end of file
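
The bulk of this file's change moves T0LM from the bare LM interface onto BaseLM: the hard-coded MAX_* constants give way to tokenizer-derived properties, and the hand-rolled greedy_until loop is replaced by the narrow _model_call/_model_generate hooks that the base class drives. A simplified sketch of that template-method split, written from the diff rather than the harness's actual BaseLM source:

    # Assumption-laden sketch of the pattern: the base class owns the request
    # loop; subclasses such as T0LM only supply model-specific hooks.
    class SketchBaseLM:
        def tok_encode(self, string):
            raise NotImplementedError  # T0LM: tokenizer.encode(..., add_special_tokens=False)

        def tok_decode(self, tokens):
            raise NotImplementedError  # T0LM: tokenizer.decode(tokens)

        def _model_generate(self, context, max_length, stopping_criteria_ids):
            raise NotImplementedError  # T0LM: self.t0.generate(...), per the diff

        def greedy_until(self, requests):
            # Shared driver replacing the per-model loop the diff deletes.
            res = []
            for context, until in requests:
                if isinstance(until, str):
                    until = [until]
                stop_ids = self.tok_encode(until[0])
                cont = self._model_generate(context, max_length=256,
                                            stopping_criteria_ids=stop_ids)
                res.append(self.tok_decode(cont))
            return res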
lm_eval/models/t5.py

@@ -2,39 +2,44 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM
+from lm_eval.base import BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
 import math


-class T5LM(LM):
-    MAX_GEN_TOKS = 256
-    MAX_INP_LENGTH = 512
-    VOCAB_SIZE = 32128
-    EOT_TOKEN_ID = 1
+class T5LM(BaseLM):
+    # MAX_GEN_TOKS = 256
+    # MAX_INP_LENGTH = 512
+    # VOCAB_SIZE = 32128
+    # EOT_TOKEN_ID = 1

     def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1):
         super().__init__()
         if device:
-            self.device = torch.device(device)
+            self._device = torch.device(device)
         else:
-            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        print(pretrained)
+            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
         self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
         self.t5.eval()
         if parallelize == "True":
-            print(parallelize)
             self.t5.parallelize()
-            self.device = torch.device('cuda:0')
+            self._device = torch.device('cuda:0')
         else:
-            self.t5.to(self.device)
-        self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained)
-        self.max_length = self.MAX_INP_LENGTH
-        self.batch_size = int(batch_size)
+            self.t5.to(self._device)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
+        # self.max_length = self.MAX_INP_LENGTH
+        self._batch_size = int(batch_size)

     @classmethod
     def create_from_arg_string(cls, arg_string, additional_config={}):
@@ -42,6 +47,53 @@ class T5LM(LM):
         args2 = {k: v for k, v in additional_config.items() if v is not None}
         return cls(**args, **args2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def max_gen_toks(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def batch_size(self):
+        # TODO: fix multi-gpu
+        return self._batch_size  # * gpus
+
+    @property
+    def device(self):
+        # TODO: fix multi-gpu
+        return self._device
+
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def _model_call(self, inputs_tok, targets_tok):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        with torch.no_grad():
+            return self.t5(**inputs_tok, labels=targets_tok["input_ids"])
+
     def loglikelihood(self, requests):
         res = []
         for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests) / self.batch_size)):
@@ -62,7 +114,7 @@ class T5LM(LM):
             targets_tok = self.tokenizer(
                 list(targets),
-                max_length=self.MAX_GEN_TOKS,
+                max_length=self.max_gen_toks,
                 padding=True,
                 # truncation=True,
                 add_special_tokens=False,
@@ -72,11 +124,7 @@ class T5LM(LM):
             for key in targets_tok:
                 targets_tok[key] = targets_tok[key][:, -(self.max_length - 1):]

-            with torch.no_grad():
-                outputs = self.t5(**inputs_tok, labels=targets_tok["input_ids"])
+            outputs = self._model_call(inputs_tok, targets_tok)

             log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
@@ -103,9 +151,6 @@ class T5LM(LM):
                 res.append(answer)
         return res

-    def loglikelihood_rolling(self, requests):
-        raise NotImplementedError
-
     def _get_stopping_criteria(self, stopping_criteria_ids):
         class MultitokenEOSCriteria(transformers.StoppingCriteria):
@@ -133,29 +178,11 @@ class T5LM(LM):
                 EOSCriteria(self.tokenizer.eos_token)
             ])

-    def greedy_until(self, requests):
-        res = []
-        for context, until in tqdm(requests):
-            if isinstance(until, str):
-                until = [until]
-            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
-            stopping_criteria_ids = self.tokenizer.encode(until[0])
-            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-            cont = self.t5.generate(
-                context_enc,
-                max_length=self.MAX_GEN_TOKS,
-                stopping_criteria=stopping_criteria,
-                do_sample=False
-            )
-            s = self.tokenizer.decode(cont[0].tolist())
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-            res.append(s)
-        return res
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        return self.t5.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
+        )
\ No newline at end of file
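
t5.py gets the same BaseLM migration as t0.py; the one extra change is swapping the hard-coded transformers.T5TokenizerFast for AutoTokenizer. For T5 checkpoints, AutoTokenizer should resolve to the same fast tokenizer class, so this generalizes the loader without changing behavior. A quick check of that assumption (requires transformers and Hub access):

    # Expected to print 'T5TokenizerFast'; if it does not, the swap is not
    # behavior-preserving for that checkpoint.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("t5-small")
    print(type(tok).__name__)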
lm_eval/tasks/__init__.py

@@ -53,6 +53,7 @@ from . import asdiv
 from . import gsm8k
 from . import storycloze
 from . import hans
+from . import gem_webnlg
 from . import gem_xsum

 # from . import e2e_nlg_cleaned
@@ -109,6 +110,7 @@ TASK_REGISTRY = {
     "wsc": superglue.SGWinogradSchemaChallenge,
     # Order by benchmark/genre?
     "coqa": coqa.CoQA,
+    "GEM/web_nlg": gem_webnlg.WebNLG,
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
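
With the import and registry entry in place, the new task resolves by its key like any other. A lookup sketch, assuming the package imports cleanly:

    # TASK_REGISTRY is the mapping edited in the diff above.
    from lm_eval import tasks

    task_class = tasks.TASK_REGISTRY["GEM/web_nlg"]
    print(task_class.DATASET_PATH, task_class.DATASET_NAME)  # GEM/web_nlg en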
lm_eval/tasks/coqa.py

@@ -118,25 +118,18 @@ class CoQA(PromptSourceTask):
         """
         target = self.doc_to_target(doc).strip()
         pred = results[0].strip().split("\n")[0]

-        print("*" * 80)
-        print(f"DOC: {doc}")
-        # print(f"PS: {self.prompt.apply(doc)}")
-        print(f"TEXT: {self.doc_to_text(doc)}")
-        print(f"TARGET: {target} END TARGET")
-        print(f"PRED: {pred} END PRED")
-        print("*" * 80)
-        # turn_id = len(doc["questions"]["input_text"])
-        # gold_list = self.get_answers(doc, turn_id)
-        # TODO: Add HF metrics mapped from promptsource metadata.
         scores = self.compute_scores([target], pred)

-        return {
+        out = {
             "f1": scores["f1"],
             "em": scores["em"],
         }
+        if self.save_examples:
+            example = {"target": target, "pred": pred}
+            return out, example
+        return out

     def higher_is_better(self):
         return {
             "f1": True,
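
The CoQA change adopts the save_examples convention introduced in evaluator.py: process_results returns (metrics, example) when examples are being saved, and bare metrics otherwise. A shape-only sketch of that contract (helper name and values are illustrative):

    # Callers (the evaluator) branch on `save_examples` to unpack this shape.
    def process_results_shape(target, pred, scores, save_examples):
        out = {"f1": scores["f1"], "em": scores["em"]}
        if save_examples:
            return out, {"target": target, "pred": pred}
        return out

    print(process_results_shape("yes", "no", {"f1": 0.5, "em": 0.0}, save_examples=True))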
lm_eval/tasks/gem_webnlg.py (new file, mode 100644)

from lm_eval.base import PromptSourceTask


class WebNLG(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "GEM/web_nlg"
    DATASET_NAME = "en"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]

    def stopping_criteria(self):
        return '*'

    def max_generation_length(self):
        return 250
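
The new task wires the English config of GEM's WebNLG into the promptsource task framework, evaluating on validation/test, stopping generation at '*', and capping generations at 250 tokens. A hedged peek at the underlying data (requires the datasets library and network access; the field name follows the GEM schema, not this diff):

    from datasets import load_dataset

    ds = load_dataset("GEM/web_nlg", "en", split="validation")
    print(ds[0]["target"])  # one human-written verbalization of the input triples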
lm_eval/tasks/hendrycks_ethics.py

@@ -277,20 +277,18 @@ class EthicsUtilitarianism(Ethics):
     DATASET_NAME = "utilitarianism"

     def training_docs(self):
-        rnd = random.Random()
         for doc in self.dataset["train"]:
-            yield self._process_doc(doc, rnd)
+            yield self._process_doc(doc)

     def validation_docs(self):
         raise NotImplementedError

     def test_docs(self):
-        rnd = random.Random()
         for doc in self.dataset["test"]:
-            yield self._process_doc(doc, rnd)
+            yield self._process_doc(doc)

-    def _process_doc(self, doc, rnd):
-        rnd.seed(doc["activity"])
+    def _process_doc(self, doc):
+        rnd = random.Random(doc["activity"])
         scenarios = [doc["activity"], doc["baseline"]]
         ordering = [0, 1]
         rnd.shuffle(ordering)
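
The utilitarianism change makes each document's scenario ordering a pure function of the document: a fresh Random seeded from doc["activity"] replaces the shared generator that callers had to construct and thread through. Behavior matches the old seed-then-shuffle, but the helper is now self-contained. A minimal demonstration of the determinism:

    import random

    def ordering_for(activity):
        rnd = random.Random(activity)  # new style: seeded per doc
        ordering = [0, 1]
        rnd.shuffle(ordering)
        return ordering

    # Same doc -> same ordering, regardless of what was processed before.
    assert ordering_for("I helped my neighbor move.") == ordering_for("I helped my neighbor move.")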
lm_eval/tasks/hendrycks_math.py

@@ -38,15 +38,15 @@ class Math(Task):
         return True

     def training_docs(self):
-        return map(self._load_doc, self.dataset["train"])
+        return map(self._process_doc, self.dataset["train"])

     def validation_docs(self):
         return NotImplemented

     def test_docs(self):
-        return map(self._load_doc, self.dataset["test"])
+        return map(self._process_doc, self.dataset["test"])

-    def _load_doc(self, doc):
+    def _process_doc(self, doc):
         doc["answer"] = self.remove_boxed(
             self.last_boxed_only_string(doc["solution"]))
         return doc
lm_eval/tasks/wikitext.py

@@ -76,15 +76,15 @@ class WikiText(PerplexityTask):
         return True

     def training_docs(self):
-        return map(self._load_doc, self.dataset["train"])
+        return map(self._process_doc, self.dataset["train"])

     def validation_docs(self):
-        return map(self._load_doc, self.dataset["validation"])
+        return map(self._process_doc, self.dataset["validation"])

     def test_docs(self):
-        return map(self._load_doc, self.dataset["test"])
+        return map(self._process_doc, self.dataset["test"])

-    def _load_doc(self, doc):
+    def _process_doc(self, doc):
         return doc["page"]

     def doc_to_target(self, doc):
lm_eval/tasks/wsc273.py

@@ -53,9 +53,9 @@ class WinogradSchemaChallenge273(Task):
         return True

     def test_docs(self):
-        return map(self._load_doc, self.dataset["test"])
+        return map(self._process_doc, self.dataset["test"])

-    def _load_doc(self, doc):
+    def _process_doc(self, doc):
         # The HF implementation of `wsc273` is not `partial evaluation` friendly.
         doc["text"] = doc["text"].replace("  ", " ")
         doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
main.py

@@ -9,27 +9,29 @@ logging.getLogger("openai").setLevel(logging.WARNING)

 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', required=True)
-    parser.add_argument('--model_args', default="")
-    parser.add_argument('--tasks', default="all_tasks")
-    parser.add_argument('--provide_description', action="store_true")
-    parser.add_argument('--num_fewshot', type=int, default=0)
-    parser.add_argument('--batch_size', type=int, default=None)
-    parser.add_argument('--device', type=str, default=None)
-    parser.add_argument('--output_path', default=None)
-    parser.add_argument('--limit', type=int, default=None)
-    parser.add_argument('--no_cache', action="store_true")
-    parser.add_argument('--description_dict_path', default=None)
-    parser.add_argument('--check_integrity', action="store_true")
+    parser.add_argument("--model", required=True)
+    parser.add_argument("--model_args", default="")
+    parser.add_argument("--tasks", default="all_tasks")
+    parser.add_argument("--provide_description", action="store_true")
+    parser.add_argument("--num_fewshot", type=int, default=0)
+    parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--device", type=str, default=None)
+    parser.add_argument("--output_path", default=None)
+    parser.add_argument("--limit", type=int, default=None)
+    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--description_dict_path", default=None)
+    parser.add_argument("--check_integrity", action="store_true")
     return parser.parse_args()


 def main():
     args = parse_args()
     assert not args.provide_description  # not implemented
     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS
@@ -38,7 +40,7 @@ def main():
     description_dict = {}
     if args.description_dict_path:
-        with open(args.description_dict_path, 'r') as f:
+        with open(args.description_dict_path, "r") as f:
             description_dict = json.load(f)

     results = evaluator.simple_evaluate(
@@ -51,11 +53,12 @@ def main():
         no_cache=args.no_cache,
         limit=args.limit,
         description_dict=description_dict,
-        check_integrity=args.check_integrity
+        check_integrity=args.check_integrity,
     )

+    print(results)
     dumped = json.dumps(results, indent=2)
     print(dumped)

     if args.output_path:
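
main.py is mostly a quote-style normalization, plus printing the raw results dict ahead of the JSON dump. The underlying call is unchanged; a sketch of driving it programmatically with the keyword arguments visible in the diff (the model and task names are illustrative, and other parameters are elided):

    from lm_eval import evaluator
    import json

    results = evaluator.simple_evaluate(
        model="gpt2",            # a registered model name; illustrative
        tasks=["GEM/web_nlg"],
        num_fewshot=0,
        limit=8,                 # testing only, per the CLI warning
        check_integrity=False,
    )
    print(json.dumps(results, indent=2))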
scripts/write_out.py

@@ -56,11 +56,11 @@ def main():
         docs = join_iters(iters)
         description = description_dict[task_name] if description_dict and task_name in description_dict else ""

+        task_name = task_name.replace('/', '_')
         with open(os.path.join(args.output_base_path, task_name), "w") as f:
             for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
                 f.write(EXAMPLE_DIVIDER.format(i=i))
-                ctx = task.fewshot_context(
+                ctx, _ = task.fewshot_context(
                     doc=doc,
                     num_fewshot=args.num_fewshot,
                     rnd=rnd,
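
write_out.py picks up two ripples: task names containing '/' (such as GEM/web_nlg) are sanitized before being used as file names, and fewshot_context now returns a (context, logging info) pair, so callers that only need the prompt discard the second element. A shape sketch of the latter (the helper here is a stand-in, not the harness's implementation):

    def fewshot_context_sketch(doc, num_fewshot, rnd=None, description=""):
        ctx = f"{description}{doc}"                  # stand-in for prompt assembly
        logging_info = {"num_fewshot": num_fewshot}  # stand-in for few-shot logging info
        return ctx, logging_info

    ctx, _ = fewshot_context_sketch("Q: 2 + 2 = ? A:", num_fewshot=0)
    print(ctx)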
setup.py

@@ -37,7 +37,6 @@ setuptools.setup(
         "pycountry==20.7.3",
         "numexpr==2.7.2",
         "lm_dataformat==0.0.20",
-        "pytest==6.2.3",
         "pybind11==2.6.2",
         "tqdm-multiprocess==0.0.11",
         "zstandard==0.15.2",
@@ -51,4 +50,5 @@ setuptools.setup(
     dependency_links=[
         "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
     ],
+    extras_require={'dev': ['pytest', 'black']}
 )