gaoqiong / lm-evaluation-harness

Commit eda365f6
Fix max generation limit
Authored Apr 28, 2022 by jon-tow
Parent: 22155f7d
Showing 13 changed files with 17 additions and 118 deletions (+17, -118).
lm_eval/base.py                   +8   -79
lm_eval/models/gpt2.py            +1   -1
lm_eval/models/gptj.py            +1   -1
lm_eval/models/t5.py              +0   -1
lm_eval/tasks/drop.py             +0   -2
lm_eval/tasks/e2e_nlg_cleaned.py  +2   -4
lm_eval/tasks/gem_asset_turk.py   +0   -6
lm_eval/tasks/gem_mlsum.py        +0   -6
lm_eval/tasks/gem_webnlg.py       +0   -6
lm_eval/tasks/gem_xsum.py         +0   -2
lm_eval/tasks/glue.py             +0   -3
lm_eval/tasks/wino_bias.py        +0   -3
templates/new_task.py             +5   -4
lm_eval/base.py

@@ -375,11 +375,10 @@ class BaseLM(LM):
         ).to(self.device)

         if max_generation_length is None:
-            max_length = context_enc.shape[1] + self.max_gen_toks
+            max_length = self.max_gen_toks
         else:
-            max_length = min(
-                max_generation_length, context_enc.shape[1] + self.max_gen_toks
-            )
+            max_length = max_generation_length

         cont = self._model_generate(
             context_enc,
             max_length,
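Taken together with the model-side hunks below, this hunk changes what max_length means by the time it reaches _model_generate: base.py now passes only the generation budget (the model's max_gen_toks default, or the task's max_generation_length when set), and the decoder-only models add the prompt length back themselves. A minimal sketch of the new bookkeeping, with illustrative values that do not come from the repo:

    # Hedged sketch of the max_length accounting after this commit.
    def generation_budget(max_gen_toks, max_generation_length=None):
        # New base.py behaviour: return only the number of tokens to generate.
        return max_gen_toks if max_generation_length is None else max_generation_length

    def hf_max_length(budget, context_len):
        # New gpt2.py/gptj.py behaviour: HuggingFace `generate(max_length=...)`
        # counts prompt tokens, so the models add the context length back.
        return budget + context_len

    # Example: a 900-token prompt with a default 256-token budget.
    print(hf_max_length(generation_budget(256), 900))  # 1156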
@@ -595,78 +594,6 @@ class Task(abc.ABC):
             )
         return ""

-    @utils.positional_deprecated
-    def fewshot_context(
-        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
-    ):
-        """Returns a fewshot context string that is made up of a prepended description
-        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
-
-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param num_fewshot: int
-            The number of fewshot examples to provide in the returned context string.
-        :param provide_description: bool
-            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
-        :param description: str
-            The task's description that will be prepended to the fewshot examples.
-        :returns: str
-            The fewshot context.
-        """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
-        assert not provide_description, (
-            "The `provide_description` arg will be removed in future versions. To prepend "
-            "a custom description to the context, supply the corresponding string via the "
-            "`description` arg."
-        )
-        if provide_description is not None:
-            # nudge people to not specify it at all
-            print(
-                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
-            )
-
-        description = description + "\n\n" if description else ""
-
-        if num_fewshot == 0:
-            labeled_examples = ""
-        else:
-            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
-            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
-            # for justification of this separator.
-            example_separator = "\n###\n"
-
-            labeled_examples = (
-                example_separator.join(
-                    [
-                        self.doc_to_text(doc) + self.doc_to_target(doc)
-                        for doc in fewshotex
-                    ]
-                )
-                + example_separator
-            )
-
-        example = self.doc_to_text(doc)
-        return description + labeled_examples + example
-
 class PromptSourceTask(Task):
     """These are the metrics from promptsource that we have
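The removed Task.fewshot_context assembled each prompt as description + few-shot examples + the evaluated example, joined with the "\n###\n" separator. A toy illustration of the resulting string, assuming hypothetical doc_to_text/doc_to_target outputs (only the separator and the ordering come from the removed code):

    separator = "\n###\n"
    description = "Answer the question.\n\n"        # hypothetical task description
    fewshot = ["Q: 2+2?\nA: 4", "Q: 3+3?\nA: 6"]    # hypothetical labeled examples
    prompt = "Q: 5+5?\nA:"                          # hypothetical evaluated example
    context = description + separator.join(fewshot) + separator + prompt
    print(context)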
@@ -691,10 +618,12 @@ class PromptSourceTask(Task):
         self.prompt = prompt
         self.save_examples = save_examples

-    def stopping_criteria(self) -> Optional[str]:
-        """Denote where the generation should end.
-
-        By default, its "\n###\n".
+    def stopping_criteria(self) -> Optional[str]:
+        """
+        Denote where the generation should end based on the few-shot example
+        separator: "\n###\n".
+
+        TODO: Handle other separators in the future.
         """
         return "\n###\n"
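PromptSourceTask now keeps a single stop string, the same "\n###\n" that separates few-shot examples. A minimal sketch of how such a stop string is typically applied to a raw generation (the helper is illustrative; the harness' actual truncation logic is not part of this diff):

    def truncate_at_stop(text: str, stop: str = "\n###\n") -> str:
        # Keep only the portion of the generation before the stop string.
        idx = text.find(stop)
        return text if idx == -1 else text[:idx]

    # truncate_at_stop("Paris\n###\nQ: next question") -> "Paris"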
lm_eval/models/gpt2.py

@@ -151,7 +151,7 @@ class HFLM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         max_length = max_length + context.size(1)
         if num_fewshot == 0:
             generations = self.gpt2.generate(
                 context,
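The stopping_criteria passed to generate here is built from token ids. A hedged sketch of what a token-id stop criterion can look like with HuggingFace transformers (an illustration only, not the repo's _get_stopping_criteria implementation):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnIds(StoppingCriteria):
        """Stop once the most recent tokens match a given id sequence."""

        def __init__(self, stop_ids):
            self.stop_ids = torch.tensor(stop_ids)

        def __call__(self, input_ids, scores, **kwargs):
            if input_ids.shape[1] < len(self.stop_ids):
                return False
            tail = input_ids[0, -len(self.stop_ids):].cpu()
            return bool(torch.equal(tail, self.stop_ids))

    # criteria = StoppingCriteriaList([StopOnIds(stopping_criteria_ids)])
    # model.generate(context, max_length=max_length, stopping_criteria=criteria)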
lm_eval/models/gptj.py

@@ -118,7 +118,7 @@ class GPTJLM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         max_length = max_length + context.size(1)
         if num_fewshot == 0:
             generations = self.gptj.generate(
                 context,
lm_eval/models/t5.py

@@ -188,7 +188,6 @@ class T5LM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         if num_fewshot == 0:
             generations = self.t5.generate(
                 context,
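Unlike the decoder-only hunks above, the T5 path does not add context.size(1) to max_length: for encoder-decoder generation, HuggingFace max_length bounds only the decoder output, not the encoder input. A tiny sketch of that distinction (illustrative, not code from the repo):

    def hf_max_length(gen_budget: int, context_len: int, is_encoder_decoder: bool) -> int:
        # Decoder-only models (gpt2.py, gptj.py): prompt tokens count toward max_length.
        # Encoder-decoder models (t5.py): max_length bounds only the generated output.
        return gen_budget if is_encoder_decoder else gen_budget + context_len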
lm_eval/tasks/drop.py

@@ -92,8 +92,6 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts

-    # def stopping_criteria(self):
-    #     return "."

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
lm_eval/tasks/e2e_nlg_cleaned.py

@@ -61,9 +61,6 @@ class E2E_NLG_Cleaned(PromptSourceTask):
     def max_generation_length(self):
         return 64

-    # def stopping_criteria(self):
-    #     return '\n\n'

     def invalid_doc_for_prompt(self, doc) -> bool:
         """The QA prompts are not applicable to all the examples, we want to filter these out."""
         return self.prompt.name.endswith("_qa") or self.prompt.name == "family_friendly_yes_no"

@@ -73,7 +70,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
         text = self.prompt.apply(doc)[0]
         return text

-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

@@ -90,6 +87,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
         request_args = {
             "stopping_criteria": self.stopping_criteria(),
             "max_generation_length": self.max_generation_length(),
+            "num_fewshot": args["num_fewshot"],
         }
         # Skip examples for which the templates are not applicable
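The E2E NLG task now threads the few-shot count through the per-request arguments alongside the stop string and generation cap; the model-side _model_generate methods above branch on num_fewshot == 0, so the value has to reach them. A hedged sketch of how the completed method plausibly ends (the rf.greedy_until call shape is an assumption based on the harness' RequestFactory pattern and is not shown in this hunk):

    def construct_requests(self, doc, ctx, args):
        request_args = {
            "stopping_criteria": self.stopping_criteria(),
            "max_generation_length": self.max_generation_length(),
            "num_fewshot": args["num_fewshot"],
        }
        # Assumption: the request is built through the harness' RequestFactory,
        # e.g. something along the lines of:
        # return [rf.greedy_until(ctx, request_args)]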
lm_eval/tasks/gem_asset_turk.py

@@ -78,15 +78,9 @@ class AssetTurk(PromptSourceTask):
     def test_docs(self):
         return self.dataset[str(self.SPLIT)]

-    # def stopping_criteria(self):
-    #     return None

     def max_generation_length(self):
         return 200

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}

 class AssetTest(AssetTurk):
     SPLIT = "test_asset"
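With the per-task stop-string overrides gone, tasks inherit PromptSourceTask's default "\n###\n" and keep only the knobs they still need, such as a generation cap. A minimal sketch of that pattern for a hypothetical new task (class name and cap are illustrative):

    class MyGenerationTask(PromptSourceTask):
        # stopping_criteria() is inherited from PromptSourceTask ("\n###\n").

        def max_generation_length(self):
            # Cap the number of generated tokens for this task only.
            return 128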
lm_eval/tasks/gem_mlsum.py

@@ -50,9 +50,6 @@ class GEMMLSUMEsBase(PromptSourceTask):
         if self.has_test_docs():
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."

 class GEMMLSUMEs(GEMMLSUMEsBase):
     '''this is for train/validation/test'''
     SPLIT = ''

@@ -98,9 +95,6 @@ class GEMMLSUMDeBase(PromptSourceTask):
         if self.has_test_docs():
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."

 class GEMMLSUMDe(GEMMLSUMDeBase):
     '''this is for train/validation/test'''
     SPLIT = ''
lm_eval/tasks/gem_webnlg.py

@@ -70,15 +70,9 @@ class WebNLG(PromptSourceTask):
         else:
             return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return None

     def max_generation_length(self):
         return 250

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}

 class WebNLGRu(WebNLG):
     DATASET_NAME = "ru"
lm_eval/tasks/gem_xsum.py

@@ -42,8 +42,6 @@ class GEMXSUMBase(PromptSourceTask):
     def has_test_docs(self):
         return True

-    def stopping_criteria(self):
-        return '.'

     def training_docs(self):
         if self.has_training_docs():
             # We cache training documents in `self._training_docs` for faster
lm_eval/tasks/glue.py

@@ -236,9 +236,6 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False

-    # def stopping_criteria(self):
-    #     return "\n###\n"

     def training_docs(self):
         if self._training_docs is None:
             self._training_docs = list(self.dataset["train"])
lm_eval/tasks/wino_bias.py

@@ -54,9 +54,6 @@ class WinoBias(PromptSourceTask):
     def test_docs(self):
         return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return "\n"

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
templates/new_task.py

@@ -72,11 +72,12 @@ class NewTask(PromptSourceTask):
         # named differently than the default `"test"`.
         return self.dataset["test"]

-    def stopping_criteria(self):
-        # Only define this method when you want to control few-shot generations on specific tokens.
-        # The default is set to '\n###\n'.
-        return "\n###\n"
+    def max_generation_length(self):
+        # Define this method when you want to control the length of few-shot
+        # generations on specific tokens. The default is `None` which gets mapped
+        # to a model's default max generation token length. E.g. see `lm_eval/models/gpt2.py:max_gen_toks()`
+        # NOTE: You may delete this function if the task does not required generation.
+        return None

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
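The template's comment points back at the first base.py hunk: when a task's max_generation_length() returns None, the model's max_gen_toks default is used; otherwise the task's value is taken as-is. A tiny sketch of that resolution order (illustrative values only):

    def resolve_generation_budget(task_max_generation_length, model_max_gen_toks):
        # None from the task means "use the model's default generation budget".
        if task_max_generation_length is None:
            return model_max_gen_toks
        return task_max_generation_length

    # resolve_generation_budget(None, 256) -> 256
    # resolve_generation_budget(64, 256)   -> 64   (e.g. the E2E NLG task above)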