Project: gaoqiong/lm-evaluation-harness

Commit eda365f6
Authored Apr 28, 2022 by jon-tow
Fix max generation limit
Parent: 22155f7d

Showing 13 changed files with 17 additions and 118 deletions (+17 -118)
lm_eval/base.py                   +8  -79
lm_eval/models/gpt2.py            +1   -1
lm_eval/models/gptj.py            +1   -1
lm_eval/models/t5.py              +0   -1
lm_eval/tasks/drop.py             +0   -2
lm_eval/tasks/e2e_nlg_cleaned.py  +2   -4
lm_eval/tasks/gem_asset_turk.py   +0   -6
lm_eval/tasks/gem_mlsum.py        +0   -6
lm_eval/tasks/gem_webnlg.py       +0   -6
lm_eval/tasks/gem_xsum.py         +0   -2
lm_eval/tasks/glue.py             +0   -3
lm_eval/tasks/wino_bias.py        +0   -3
templates/new_task.py             +5   -4
lm_eval/base.py

@@ -375,11 +375,10 @@ class BaseLM(LM):
         ).to(self.device)

         if max_generation_length is None:
-            max_length = context_enc.shape[1] + self.max_gen_toks
+            max_length = self.max_gen_toks
         else:
-            max_length = min(
-                max_generation_length, context_enc.shape[1] + self.max_gen_toks
-            )
+            max_length = max_generation_length

         cont = self._model_generate(
             context_enc,
             max_length,
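To illustrate the effect of this hunk, here is a minimal sketch (with made-up token counts) of how `max_length` is now interpreted as a pure generation budget rather than prompt length plus budget; the values below are assumptions for illustration, not taken from the commit:

    # Sketch only: values are hypothetical, not from the harness.
    context_len = 900             # tokens already in the prompt (assumed)
    max_gen_toks = 256            # model's default generation budget (assumed)
    max_generation_length = None  # task did not request a specific budget

    # New behaviour from the hunk above:
    if max_generation_length is None:
        max_length = max_gen_toks            # 256 new tokens, independent of the prompt
    else:
        max_length = max_generation_length

    # The prompt length is re-added inside _model_generate for decoder-only backends
    # (see the lm_eval/models/gpt2.py hunk below), where a *total* length is expected:
    total_length = max_length + context_len  # 256 + 900 = 1156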
@@ -595,78 +594,6 @@ class Task(abc.ABC):
             )
         return ""

-    @utils.positional_deprecated
-    def fewshot_context(
-        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
-    ):
-        """Returns a fewshot context string that is made up of a prepended description
-        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
-
-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param num_fewshot: int
-            The number of fewshot examples to provide in the returned context string.
-        :param provide_description: bool
-            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
-        :param description: str
-            The task's description that will be prepended to the fewshot examples.
-        :returns: str
-            The fewshot context.
-        """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
-        assert not provide_description, (
-            "The `provide_description` arg will be removed in future versions. To prepend "
-            "a custom description to the context, supply the corresponding string via the "
-            "`description` arg."
-        )
-        if provide_description is not None:
-            # nudge people to not specify it at all
-            print(
-                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
-            )
-
-        description = description + "\n\n" if description else ""
-
-        if num_fewshot == 0:
-            labeled_examples = ""
-        else:
-            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
-            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
-            # for justification of this separator.
-            example_separator = "\n###\n"
-            labeled_examples = (
-                example_separator.join(
-                    [
-                        self.doc_to_text(doc) + self.doc_to_target(doc)
-                        for doc in fewshotex
-                    ]
-                )
-                + example_separator
-            )
-
-        example = self.doc_to_text(doc)
-        return description + labeled_examples + example
-
-
 class PromptSourceTask(Task):
     """These are the metrics from promptsource that we have
@@ -691,10 +618,12 @@ class PromptSourceTask(Task):
         self.prompt = prompt
         self.save_examples = save_examples

     def stopping_criteria(self) -> Optional[str]:
-        """Denote where the generation should end.
-        By default, its "\n###\n".
+        """
+        Denote where the generation should end based on the few-shot example
+        separator: "\n###\n".
+        TODO: Handle other separators in the future.
         """
         return "\n###\n"
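For context, a hedged illustration (not part of the diff) of why the stopping criterion reuses the few-shot example separator "\n###\n": the prompt is built with that separator between examples, so truncating generation at its first occurrence keeps only the answer to the final query. The questions and model output below are invented:

    # Hedged illustration; the example questions and the raw generation are invented.
    separator = "\n###\n"
    fewshot = separator.join(["Q: 2+2?\nA: 4", "Q: 3+3?\nA: 6"]) + separator
    prompt = fewshot + "Q: 5+5?\nA:"

    raw_generation = " 10\n###\nQ: 7+7?\nA: 14"   # model may start a fake next example
    answer = raw_generation.split(separator)[0]   # cut at the separator -> " 10"
    print(answer)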
lm_eval/models/gpt2.py

@@ -151,7 +151,7 @@ class HFLM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        max_length = max_length + context.size(1)
         if num_fewshot == 0:
             generations = self.gpt2.generate(
                 context,
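A minimal, self-contained sketch (using a small public checkpoint; not harness code) of why the added line folds the prompt length back in: for decoder-only Hugging Face models, `generate`'s `max_length` counts prompt tokens plus newly generated tokens, so the generation budget coming from base.py must be shifted by the context size.

    # Illustrative only: shows why max_length must include the prompt length for
    # decoder-only HF models. Uses a small public checkpoint, not the harness itself.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    context = tok("The quick brown fox", return_tensors="pt").input_ids
    gen_budget = 20  # analogous to the max_length passed in from base.py after the fix

    out = model.generate(
        context,
        max_length=gen_budget + context.size(1),  # total length = prompt + new tokens
        pad_token_id=tok.eos_token_id,
    )
    print(tok.decode(out[0], skip_special_tokens=True))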
lm_eval/models/gptj.py

@@ -118,7 +118,7 @@ class GPTJLM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        max_length = max_length + context.size(1)
         if num_fewshot == 0:
             generations = self.gptj.generate(
                 context,
lm_eval/models/t5.py

@@ -188,7 +188,6 @@ class T5LM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         if num_fewshot == 0:
             generations = self.t5.generate(
                 context,
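By contrast, T5 is an encoder-decoder model and `generate`'s `max_length` counts only decoder-side tokens, which is presumably why no prompt length is added back in this file. An illustrative sketch under that assumption (small public checkpoint, not harness code):

    # Illustrative only: for an encoder-decoder model, max_length applies to the
    # decoder output, so the prompt length does not need to be added.
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tok = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    context = tok("summarize: The quick brown fox jumps over the lazy dog.",
                  return_tensors="pt").input_ids
    out = model.generate(context, max_length=20)  # 20 decoder tokens, prompt not counted
    print(tok.decode(out[0], skip_special_tokens=True))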
lm_eval/tasks/drop.py

@@ -92,8 +92,6 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts

-    # def stopping_criteria(self):
-    #     return "."
-
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
lm_eval/tasks/e2e_nlg_cleaned.py

@@ -61,9 +61,6 @@ class E2E_NLG_Cleaned(PromptSourceTask):
     def max_generation_length(self):
         return 64

-    # def stopping_criteria(self):
-    #     return '\n\n'
-
     def invalid_doc_for_prompt(self, doc) -> bool:
         """The QA prompts are not applicable to all the examples, we want to filter these out."""
         return self.prompt.name.endswith("_qa") or self.prompt.name == "family_friendly_yes_no"

@@ -73,7 +70,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
         text = self.prompt.apply(doc)[0]
         return text

-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

@@ -90,6 +87,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
         request_args = {
             "stopping_criteria": self.stopping_criteria(),
             "max_generation_length": self.max_generation_length(),
+            "num_fewshot": args["num_fewshot"],
         }

         # Skip examples for which the templates are not applicable
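A hedged sketch of how the extra `args` parameter is consumed: the request arguments now carry `num_fewshot` alongside the stopping criterion and generation budget, so the model side can decide how to stop. The `rf.greedy_until(ctx, request_args)` call shape is assumed from the harness templates rather than shown in this diff:

    # Hedged sketch; `rf` is the harness's RequestFactory (lm_eval.base.rf), and the
    # greedy_until signature is assumed, not confirmed by this commit.
    from lm_eval.base import rf

    def construct_requests(self, doc, ctx, args):
        request_args = {
            "stopping_criteria": self.stopping_criteria(),
            "max_generation_length": self.max_generation_length(),
            "num_fewshot": args["num_fewshot"],  # lets _model_generate pick a stop strategy
        }
        return rf.greedy_until(ctx, request_args)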
lm_eval/tasks/gem_asset_turk.py

@@ -78,15 +78,9 @@ class AssetTurk(PromptSourceTask):
     def test_docs(self):
         return self.dataset[str(self.SPLIT)]

-    # def stopping_criteria(self):
-    #     return None
-
     def max_generation_length(self):
         return 200

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}
-

 class AssetTest(AssetTurk):
     SPLIT = "test_asset"
lm_eval/tasks/gem_mlsum.py

@@ -50,9 +50,6 @@ class GEMMLSUMEsBase(PromptSourceTask):
         if self.has_test_docs():
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."
-

 class GEMMLSUMEs(GEMMLSUMEsBase):
     '''this is for train/validation/test'''
     SPLIT = ''

@@ -98,9 +95,6 @@ class GEMMLSUMDeBase(PromptSourceTask):
         if self.has_test_docs():
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."
-

 class GEMMLSUMDe(GEMMLSUMDeBase):
     '''this is for train/validation/test'''
     SPLIT = ''
lm_eval/tasks/gem_webnlg.py

@@ -70,15 +70,9 @@ class WebNLG(PromptSourceTask):
         else:
             return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return None
-
     def max_generation_length(self):
         return 250

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}
-

 class WebNLGRu(WebNLG):
     DATASET_NAME = "ru"
lm_eval/tasks/gem_xsum.py

@@ -42,8 +42,6 @@ class GEMXSUMBase(PromptSourceTask):
     def has_test_docs(self):
         return True

-    def stopping_criteria(self):
-        return '.'
-
     def training_docs(self):
         if self.has_training_docs():
             # We cache training documents in `self._training_docs` for faster
lm_eval/tasks/glue.py

@@ -236,9 +236,6 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False

-    # def stopping_criteria(self):
-    #     return "\n###\n"
-
     def training_docs(self):
         if self._training_docs is None:
             self._training_docs = list(self.dataset["train"])
lm_eval/tasks/wino_bias.py

@@ -54,9 +54,6 @@ class WinoBias(PromptSourceTask):
     def test_docs(self):
         return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return "\n"
-
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
templates/new_task.py

@@ -72,11 +72,12 @@ class NewTask(PromptSourceTask):
         # named differently than the default `"test"`.
         return self.dataset["test"]

-    def stopping_criteria(self):
-        # Only define this method when you want to control few-shot generations on specific tokens.
-        # The default is set to '\n###\n'.
+    def max_generation_length(self):
+        # Define this method when you want to control the length of few-shot
+        # generations on specific tokens. The default is `None` which gets mapped
+        # to a model's default max generation token length. E.g. see `lm_eval/models/gpt2.py:max_gen_toks()`
         # NOTE: You may delete this function if the task does not required generation.
-        return "\n###\n"
+        return None

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
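As a usage sketch of the updated template, a hypothetical task that overrides `max_generation_length`; the class name, dataset id, and the 128-token budget are invented for illustration and are not part of the commit:

    # Hypothetical task following the updated template above.
    from lm_eval.base import PromptSourceTask

    class MySummarizationTask(PromptSourceTask):
        DATASET_PATH = "my_org/my_summarization_dataset"  # assumed dataset id
        DATASET_NAME = None

        def has_test_docs(self):
            return True

        def test_docs(self):
            return self.dataset["test"]

        def max_generation_length(self):
            # Cap generations at 128 new tokens; returning None would fall back to
            # the model's default budget (e.g. max_gen_toks in lm_eval/models/gpt2.py).
            return 128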