gaoqiong / lm-evaluation-harness / Commits / 073b0808

Unverified commit 073b0808, authored Apr 28, 2022 by Oskar van der Wal; committed by GitHub, Apr 28, 2022.

Merge branch 'bigscience-workshop:master' into master

Parents: 2d861a29, 29bff88d

Changes: 14. Showing 14 changed files with 166 additions and 67 deletions (+166, -67).
lm_eval/base.py                    +20  -14
lm_eval/evaluator.py                +2   -1
lm_eval/models/gpt2.py             +24  -11
lm_eval/models/gptj.py             +25  -10
lm_eval/models/t0.py               +26   -8
lm_eval/models/t5.py               +18   -8
lm_eval/tasks/coqa.py               +2   -2
lm_eval/tasks/drop.py               +2   -2
lm_eval/tasks/gem_asset_turk.py     +2   -2
lm_eval/tasks/gem_webnlg.py         +2   -2
lm_eval/tasks/glue.py               +2   -2
lm_eval/tasks/superglue.py         +36   -0
lm_eval/tasks/wino_bias.py          +2   -2
templates/new_task.py               +3   -3
lm_eval/base.py

@@ -353,11 +353,13 @@ class BaseLM(LM):
         for context, request_args in tqdm(reord.get_reordered()):
             stopping_criteria = request_args["stopping_criteria"]
             max_generation_length = request_args["max_generation_length"]
+            num_fewshot = request_args["num_fewshot"]
             assert isinstance(stopping_criteria, str) or stopping_criteria is None
             assert (
                 isinstance(max_generation_length, int) or max_generation_length is None
             )
+            assert isinstance(num_fewshot, int) or num_fewshot is None
             if stopping_criteria is None:
                 until = [self.eot_token]

@@ -382,9 +384,10 @@ class BaseLM(LM):
                 context_enc,
                 max_length,
                 torch.tensor(primary_until),
+                num_fewshot,
             )
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
+            s = self.tok_decode(cont.tolist())
             for term in until:
                 s = s.split(term)[0]

@@ -536,7 +539,7 @@ class Task(abc.ABC):
         pass

     @abstractmethod
-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

@@ -546,6 +549,8 @@ class Task(abc.ABC):
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
+        :param args: dict
+            The specifics of the context, including number of few shots.
         """
         pass

@@ -689,11 +694,9 @@ class PromptSourceTask(Task):
     def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
-        For example, for coqa, this is '\nQ:' and for drop '.'.
-        By default, its None, meaning to generate up to max or EOT, whichever comes first.
+        By default, its "\n###\n".
         """
-        return None
+        return "\n###\n"

     def max_generation_length(self) -> Optional[int]:
         """Denote where the max length of the generation if it is obvious from the task."""

@@ -724,7 +727,7 @@ class PromptSourceTask(Task):
         text, _ = self.prompt.apply(doc)
         return text

-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

@@ -734,6 +737,8 @@ class PromptSourceTask(Task):
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
+        :param args: dict
+            The specifics of the context, including number of few shots.
         """
         _requests = []
         answer_choices_list = self.prompt.get_answer_choices_list(doc)

@@ -749,6 +754,7 @@ class PromptSourceTask(Task):
             request_args = {
                 "stopping_criteria": self.stopping_criteria(),
                 "max_generation_length": self.max_generation_length(),
+                "num_fewshot": args["num_fewshot"],
             }
             cont_request = rf.greedy_until(ctx, request_args)
             _requests.append(cont_request)

@@ -915,12 +921,12 @@ class PromptSourceTask(Task):
         if num_fewshot == 0:
             labeled_examples = ""
-            fewshotex, fewshotidx, fewshotsource = [], [], None
+            fewshotex, fewshotidx, self.fewshotsource = [], [], None
         else:
             # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
             if self.has_training_docs():
                 fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-                fewshotsource = "train"
+                self.fewshotsource = "train"
             else:
                 if self._fewshot_docs is None:
                     self._fewshot_docs = list(

@@ -929,18 +935,18 @@ class PromptSourceTask(Task):
                         else self.test_docs()
                     )
                 if self.has_validation_docs():
-                    self.fewshotsource = "val"
+                    self.fewshotsource = "val"
                 elif self.test_docs():
-                    fewshotsource = "test"
+                    self.fewshotsource = "test"
                 fewshotex, fewshotidx = self._get_fewshot_examples(
                     self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
                 )
-                fewshotex, fewshotidx = [
+                fewshotex, fewshotidx = zip(*[
                     (shot, idx)
                     for shot, idx in zip(fewshotex, fewshotidx)
                     if shot != doc
-                ]
+                ])
             # get rid of the doc that's the one we're evaluating, if it's in the fewshot
             fewshotex, fewshotidx = (
                 fewshotex[:num_fewshot],

@@ -966,7 +972,7 @@ class PromptSourceTask(Task):
                 ctx,
                 {
                     "fewshot_idx": fewshotidx,
-                    "fewshot_source": fewshotsource,
+                    "fewshot_source": self.fewshotsource,
                     "fewshot_num": num_fewshot,
                     "ctx": ctx,
                 },
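Taken together, the base.py changes thread the number of few-shot examples through request construction: construct_requests now receives an args dict and is expected to forward args["num_fewshot"] into the greedy_until request, which BaseLM.greedy_until then validates and hands to _model_generate. Below is a minimal task-side sketch consistent with the new signature; it is not part of the diff, and the class name and the chosen values are illustrative only.

    # Hypothetical generation task following the new construct_requests contract.
    from lm_eval.base import PromptSourceTask, rf

    class MyGenerationTask(PromptSourceTask):
        def stopping_criteria(self):
            return "\n###\n"          # same value as the new PromptSourceTask default

        def max_generation_length(self):
            return 128                # illustrative cap

        def construct_requests(self, doc, ctx, args):
            request_args = {
                "stopping_criteria": self.stopping_criteria(),
                "max_generation_length": self.max_generation_length(),
                "num_fewshot": args["num_fewshot"],   # new field read by BaseLM.greedy_until
            }
            return [rf.greedy_until(ctx, request_args)]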
lm_eval/evaluator.py

@@ -206,7 +206,8 @@ def evaluate(
             doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
         )
         fewshotex_logging_info["doc_id"] = original_doc_id
-        reqs = task.construct_requests(doc, ctx)
+        args = {"num_fewshot": num_fewshot}
+        reqs = task.construct_requests(doc, ctx, args)
         if not isinstance(reqs, (list, tuple)):
             reqs = [reqs]
         for i, req in enumerate(reqs):
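Because the evaluator now always passes the extra args dict, any task that still overrides the old two-argument form will fail at call time. A hedged illustration of the failure mode (the class here is hypothetical, not part of the diff):

    from lm_eval.base import PromptSourceTask, rf

    class OldStyleTask(PromptSourceTask):
        def construct_requests(self, doc, ctx):      # pre-merge signature
            return [rf.greedy_until(ctx, {"stopping_criteria": None})]

    # task.construct_requests(doc, ctx, {"num_fewshot": 0}) would now raise roughly:
    # TypeError: construct_requests() takes 3 positional arguments but 4 were given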
lm_eval/models/gpt2.py

@@ -12,6 +12,7 @@ class HFLM(BaseLM):
         subfolder=None,
         tokenizer=None,
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()

@@ -32,7 +33,7 @@ class HFLM(BaseLM):
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
             pretrained,
             revision=revision + ("/" + subfolder if subfolder is not None else ""),
-        ).to(self.device)
+        )
         self.gpt2.eval()

         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2

@@ -68,9 +69,11 @@ class HFLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gpt2 = nn.DataParallel(self.gpt2)
+        if parallelize:
+            self.gpt2.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gpt2.to(self._device)

     @property
     def eot_token(self):

@@ -146,16 +149,26 @@ class HFLM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gpt2.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.gpt2.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gpt2.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        # Remove the context from the generations
+        return generations[0, context.shape[1]:]

 # for backwards compatibility
 GPT2LM = HFLM
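The new parallelize flag wires HFLM up to the transformers model-parallel API (GPT-2 and GPT-J exposed a parallelize() method in transformers releases of this era): layers are sharded across the visible GPUs while inputs stay on cuda:0, which is why self._device is reset in the diff. A hedged usage sketch, not from the diff itself; the pretrained keyword is inferred from the from_pretrained call above and the device keyword from gptj.py's signature, and the model names are just examples.

    from lm_eval.models.gpt2 import HFLM

    # Single-GPU (default): the whole model is moved to self._device.
    lm = HFLM(pretrained="gpt2", device="cuda", batch_size=4)

    # Model parallel: layers are spread across all visible GPUs;
    # inputs are fed to cuda:0, matching self._device in the diff.
    lm_mp = HFLM(pretrained="gpt2-xl", device="cuda", batch_size=1, parallelize=True)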
lm_eval/models/gptj.py

@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
         self,
         device="cuda",
         batch_size=1,
+        parallelize=False,
     ):
         super().__init__()

@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

         # TODO: fix multi-gpu
-        # gpus = torch.cuda.device_count()
-        # if gpus > 1:
-        #     self.gptj = nn.DataParallel(self.gptj)
+        if parallelize:
+            self.gptj.parallelize()
+            self._device = torch.device('cuda:0')
+        else:
+            self.gptj.to(self._device)

     @property
     def eot_token(self):

@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.gptj.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gptj.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        # Remove the context from the generations
+        return generations[0, context.shape[1]:]
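The zero-shot/few-shot split in _model_generate leans on two different Hugging Face stopping mechanisms: with no few-shot examples, generation can simply stop at the model's EOS token (eos_token_id), while with few-shot examples the custom StoppingCriteriaList built by _get_stopping_criteria halts on the task delimiter so the model does not keep generating further "shots". A rough standalone sketch of the same two calls against the plain transformers API; the model, tokenizer, and prompt here are illustrative, not taken from the diff.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    context = tok("Q: What is the capital of France?\nA:", return_tensors="pt").input_ids

    # Zero-shot path: rely on the model's own EOS token.
    out = model.generate(context, max_length=64, eos_token_id=tok.eos_token_id, do_sample=False)

    # Few-shot path: a StoppingCriteriaList (as returned by _get_stopping_criteria in the diff)
    # stops on a task-specific delimiter such as "\n###\n":
    # out = model.generate(context, max_length=64, stopping_criteria=my_criteria, do_sample=False)

    # Decoder-only models echo the prompt, so the context tokens are stripped before decoding,
    # mirroring `return generations[0, context.shape[1]:]` above.
    print(tok.decode(out[0, context.shape[1]:]))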
lm_eval/models/t0.py

@@ -56,7 +56,7 @@ class T0LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):

@@ -94,6 +94,14 @@ class T0LM(BaseLM):
             inputs, targets = zip(*chunk)

+            # Fill in empty encoder inputs with eos_token
+            inputs = (
+                f"{self.eot_token}"
+                if len(input_) == 0
+                else input_
+                for input_ in inputs
+            )
+
             inputs_tok = self.tokenizer(
                 list(inputs),
                 max_length=self.max_length,

@@ -172,11 +180,21 @@ class T0LM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t0.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t0.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
lm_eval/models/t5.py

@@ -62,7 +62,7 @@ class T5LM(BaseLM):
     @property
     def max_gen_toks(self):
-        return self.tokenizer.model_max_length
+        return 256

     @property
     def batch_size(self):

@@ -186,11 +186,21 @@ class T5LM(BaseLM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t5.generate(
-            context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if num_fewshot == 0:
+            generations = self.t5.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t5.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
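t0.py and t5.py pick up the same num_fewshot branch, with two seq2seq-specific differences visible in the diffs: max_gen_toks is capped at 256 instead of tokenizer.model_max_length, and the raw generations[0] is returned without slicing, because an encoder-decoder's decoder output does not contain the encoder prompt. A rough illustration using the plain transformers seq2seq API; the model and prompt are just examples, not part of the diff.

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    enc = tok("summarize: The quick brown fox jumps over the lazy dog.", return_tensors="pt")
    out = model.generate(enc.input_ids, max_length=256, do_sample=False)

    # No context stripping needed: the decoder starts from scratch, so out[0] is the
    # generation itself (cf. `return generations[0]` in the t0/t5 diffs above).
    print(tok.decode(out[0], skip_special_tokens=True))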
lm_eval/tasks/coqa.py

@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }

-    def stopping_criteria(self):
-        return "\n\n"
+    # def stopping_criteria(self):
+    #     return "\n\n"

     # def construct_requests(self, doc, ctx):
     #     """Uses RequestFactory to construct Requests and returns an iterable of
lm_eval/tasks/drop.py

@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts

-    def stopping_criteria(self):
-        return "."
+    # def stopping_criteria(self):
+    #     return "."

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
lm_eval/tasks/gem_asset_turk.py

@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
     def test_docs(self):
         return self.dataset[str(self.SPLIT)]

-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None

     def max_generation_length(self):
         return 200
lm_eval/tasks/gem_webnlg.py

@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
         else:
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return None
+    # def stopping_criteria(self):
+    #     return None

     def max_generation_length(self):
         return 250
lm_eval/tasks/glue.py

@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False

-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n###\n"

     def training_docs(self):
         if self._training_docs is None:
lm_eval/tasks/superglue.py

@@ -305,3 +305,39 @@ class SGWinogradSchemaChallenge(PromptSourceTask):
     def aggregation(self):
         return {"acc": mean}
+
+
+class WinogenderSchemaDiagnostics(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "super_glue"
+    DATASET_NAME = "axg"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return self.dataset["test"]
+
+
+class BroadcoverageDiagnostics(PromptSourceTask):
+    VERSION = 0
+    DATASET_PATH = "super_glue"
+    DATASET_NAME = "axb"
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def test_docs(self):
+        return self.dataset["test"]
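The two new task classes point at the SuperGLUE diagnostic configs on the Hugging Face Hub: axg (Winogender schema diagnostics) and axb (broad-coverage diagnostics). Both configs ship only a test split, which is why has_test_docs is the only split flag returning True. A quick check of what test_docs() iterates over, using the datasets library directly (this snippet is illustrative, not part of the diff):

    from datasets import load_dataset

    axg = load_dataset("super_glue", "axg")   # Winogender schema diagnostics
    axb = load_dataset("super_glue", "axb")   # broad-coverage diagnostics

    print(axg)             # only a "test" split exists for these configs
    print(axg["test"][0])  # one diagnostic example, as returned by test_docs()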
lm_eval/tasks/wino_bias.py

@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
     def test_docs(self):
         return self.dataset["test"]

-    def stopping_criteria(self):
-        return "\n"
+    # def stopping_criteria(self):
+    #     return "\n"

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
templates/new_task.py

@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
         return self.dataset["test"]

     def stopping_criteria(self):
-        # TODO: Denote the string where the generation should be split.
-        # For example, for `coqa`, this is '\nQ:' and for `drop` '.'.
+        # Only define this method when you want to control few-shot generations on specific tokens.
+        # The default is set to '\n###\n'.
         # NOTE: You may delete this function if the task does not required generation.
-        return None
+        return "\n###\n"

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
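With the template now mirroring the new PromptSourceTask default of "\n###\n", a task that is happy with that delimiter can drop the override entirely, as the commented-out methods in coqa.py, drop.py, glue.py, and the other task files above do; an override only makes sense when a task needs a different stop string. A hedged sketch of both options for hypothetical tasks (names and values are illustrative):

    from lm_eval.base import PromptSourceTask

    class MyTaskUsingDefaults(PromptSourceTask):
        # No stopping_criteria override: generation stops at the "\n###\n"
        # few-shot delimiter (or at EOS on the zero-shot path).
        pass

    class MyDotTerminatedTask(PromptSourceTask):
        def stopping_criteria(self):
            # Override only to stop few-shot generations on a specific token,
            # e.g. a period for short-answer tasks like drop.
            return "."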