Commit eda365f6 authored by jon-tow

Fix max generation limit

parent 22155f7d
@@ -375,11 +375,10 @@ class BaseLM(LM):
         ).to(self.device)
         if max_generation_length is None:
-            max_length = context_enc.shape[1] + self.max_gen_toks
+            max_length = self.max_gen_toks
         else:
-            max_length = min(
-                max_generation_length, context_enc.shape[1] + self.max_gen_toks
-            )
+            max_length = max_generation_length
         cont = self._model_generate(
             context_enc,
             max_length,
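Review note: the hunk above changes what `max_length` means when it reaches `_model_generate`. A minimal sketch of the before/after semantics, assuming `max_gen_toks` is the model's default generation allowance; the helper functions here are illustrative, not part of the diff:

```python
def old_budget(context_len, max_generation_length, max_gen_toks):
    # Before: the cap bundled the prompt length, so the value passed to
    # `_model_generate` mixed prompt tokens and new tokens.
    if max_generation_length is None:
        return context_len + max_gen_toks
    return min(max_generation_length, context_len + max_gen_toks)


def new_budget(context_len, max_generation_length, max_gen_toks):
    # After: a pure generation budget; decoder-only wrappers re-add the
    # prompt length themselves (see the gpt2/gptj hunks below).
    if max_generation_length is None:
        return max_gen_toks
    return max_generation_length


# E.g. a 1000-token prompt, max_gen_toks=256, task cap of 64:
print(old_budget(1000, 64, 256))  # 64, ambiguous about prompt tokens
print(new_budget(1000, 64, 256))  # 64, unambiguously 64 *new* tokens
```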
@@ -595,78 +594,6 @@ class Task(abc.ABC):
         )
         return ""

-    @utils.positional_deprecated
-    def fewshot_context(
-        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
-    ):
-        """Returns a fewshot context string that is made up of a prepended description
-        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
-
-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param num_fewshot: int
-            The number of fewshot examples to provide in the returned context string.
-        :param provide_description: bool
-            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
-        :param description: str
-            The task's description that will be prepended to the fewshot examples.
-        :returns: str
-            The fewshot context.
-        """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
-        assert not provide_description, (
-            "The `provide_description` arg will be removed in future versions. To prepend "
-            "a custom description to the context, supply the corresponding string via the "
-            "`description` arg."
-        )
-        if provide_description is not None:
-            # nudge people to not specify it at all
-            print(
-                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
-            )
-
-        description = description + "\n\n" if description else ""
-
-        if num_fewshot == 0:
-            labeled_examples = ""
-        else:
-            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
-            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
-            # for justification of this separator.
-            example_separator = "\n###\n"
-
-            labeled_examples = (
-                example_separator.join(
-                    [
-                        self.doc_to_text(doc) + self.doc_to_target(doc)
-                        for doc in fewshotex
-                    ]
-                )
-                + example_separator
-            )
-
-        example = self.doc_to_text(doc)
-        return description + labeled_examples + example

 class PromptSourceTask(Task):
     """These are the metrics from promptsource that we have
@@ -691,10 +618,12 @@ class PromptSourceTask(Task):
         self.prompt = prompt
         self.save_examples = save_examples

-    def stopping_criteria(self) -> Optional[str]:
-        """Denote where the generation should end.
-
-        By default, its "\n###\n".
-        """
+    def stopping_criteria(self) -> Optional[str]:
+        """
+        Denote where the generation should end based on the few-shot example
+        separator: "\n###\n".
+
+        TODO: Handle other separators in the future.
+        """
         return "\n###\n"
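Review note: `stopping_criteria()` returns a plain string; the model wrappers below turn it into token-level criteria via `_get_stopping_criteria`, whose body is not shown in this diff. A sketch of one way such a criterion can be built with HuggingFace `transformers` (the class below is an assumption, not the repo's actual helper):

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopSequenceCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with the stop token ids."""

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids  # e.g. tokenizer.encode("\n###\n")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Compare the tail of the running sequence against the stop ids.
        if input_ids.shape[1] < len(self.stop_ids):
            return False
        return input_ids[0, -len(self.stop_ids):].tolist() == self.stop_ids


# usage: model.generate(context, stopping_criteria=StoppingCriteriaList([StopSequenceCriteria(stop_ids)]))
```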
@@ -151,7 +151,7 @@ class HFLM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         max_length = max_length + context.size(1)
         if num_fewshot == 0:
             generations = self.gpt2.generate(
                 context,
@@ -118,7 +118,7 @@ class GPTJLM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         max_length = max_length + context.size(1)
         if num_fewshot == 0:
             generations = self.gptj.generate(
                 context,
@@ -188,7 +188,6 @@ class T5LM(BaseLM):
     def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
         stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
         if num_fewshot == 0:
             generations = self.t5.generate(
                 context,
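Review note: notice the asymmetry across the three wrappers. `HFLM` and `GPTJLM` add `context.size(1)` back because, for decoder-only models, HF `generate`'s `max_length` counts the prompt tokens, while the T5 wrapper passes the budget through untouched since an encoder-decoder's output sequence does not contain the prompt. A sketch of that assumption:

```python
def effective_max_length(budget, context, is_encoder_decoder):
    # `budget` is the pure generation allowance computed in BaseLM above.
    if is_encoder_decoder:
        return budget                    # T5-style: decoder starts fresh
    return budget + context.size(1)      # GPT-style: prompt counts toward max_length

# Recent transformers versions express the same intent directly with
# `max_new_tokens=budget` instead of adjusting `max_length` by hand.
```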
@@ -92,8 +92,6 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts

-    # def stopping_criteria(self):
-    #     return "."

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -61,9 +61,6 @@ class E2E_NLG_Cleaned(PromptSourceTask):
     def max_generation_length(self):
         return 64

-    # def stopping_criteria(self):
-    #     return '\n\n'

     def invalid_doc_for_prompt(self, doc) -> bool:
         """The QA prompts are not applicable to all the examples, we want to filter these out."""
         return self.prompt.name.endswith("_qa") or self.prompt.name == "family_friendly_yes_no"
@@ -73,7 +70,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
         text = self.prompt.apply(doc)[0]
         return text

-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -90,6 +87,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
         request_args = {
             "stopping_criteria": self.stopping_criteria(),
             "max_generation_length": self.max_generation_length(),
+            "num_fewshot": args["num_fewshot"],
         }
         # Skip examples for which the templates are not applicable
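Review note: the new `args` parameter threads evaluation-time settings into the request; here it carries `num_fewshot` so the model wrappers above can branch on `num_fewshot == 0`. A sketch of how the extended method might look in full (the `rf.greedy_until` call is truncated out of this hunk, so passing the dict of request args to it is an assumption based on the surrounding code):

```python
def construct_requests(self, doc, ctx, args):
    request_args = {
        "stopping_criteria": self.stopping_criteria(),
        "max_generation_length": self.max_generation_length(),
        # Newly threaded through so model wrappers can distinguish
        # zero-shot from few-shot generation.
        "num_fewshot": args["num_fewshot"],
    }
    return [rf.greedy_until(ctx, request_args)]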
@@ -78,15 +78,9 @@ class AssetTurk(PromptSourceTask):
     def test_docs(self):
         return self.dataset[str(self.SPLIT)]

-    # def stopping_criteria(self):
-    #     return None

     def max_generation_length(self):
         return 200

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}

 class AssetTest(AssetTurk):
     SPLIT = "test_asset"
@@ -50,9 +50,6 @@ class GEMMLSUMEsBase(PromptSourceTask):
         if self.has_test_docs():
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."

 class GEMMLSUMEs(GEMMLSUMEsBase):
     '''this is for train/validation/test'''
     SPLIT = ''
@@ -98,9 +95,6 @@ class GEMMLSUMDeBase(PromptSourceTask):
         if self.has_test_docs():
             return self.dataset["test"]

-    def stopping_criteria(self):
-        return "."

 class GEMMLSUMDe(GEMMLSUMDeBase):
     '''this is for train/validation/test'''
     SPLIT = ''
@@ -70,15 +70,9 @@ class WebNLG(PromptSourceTask):
         else:
             return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return None

     def max_generation_length(self):
         return 250

-    # def higher_is_better(self):
-    #     return {"bleu": True, "rouge": True}

 class WebNLGRu(WebNLG):
     DATASET_NAME = "ru"
@@ -42,8 +42,6 @@ class GEMXSUMBase(PromptSourceTask):
     def has_test_docs(self):
         return True

-    def stopping_criteria(self):
-        return '.'

     def training_docs(self):
         if self.has_training_docs():
             # We cache training documents in `self._training_docs` for faster
@@ -236,9 +236,6 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False

-    # def stopping_criteria(self):
-    #     return "\n###\n"

     def training_docs(self):
         if self._training_docs is None:
             self._training_docs = list(self.dataset["train"])
@@ -54,9 +54,6 @@ class WinoBias(PromptSourceTask):
     def test_docs(self):
         return self.dataset["test"]

-    # def stopping_criteria(self):
-    #     return "\n"

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
@@ -72,11 +72,12 @@ class NewTask(PromptSourceTask):
         # named differently than the default `"test"`.
         return self.dataset["test"]

-    def stopping_criteria(self):
-        # Only define this method when you want to control few-shot generations on specific tokens.
-        # The default is set to '\n###\n'.
-        return "\n###\n"
+    def max_generation_length(self):
+        # Define this method when you want to control the length of few-shot
+        # generations on specific tokens. The default is `None`, which gets mapped
+        # to a model's default max generation token length. E.g. see `lm_eval/models/gpt2.py:max_gen_toks()`.
+        # NOTE: You may delete this function if the task does not require generation.
+        return None

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of