Commit eda365f6 authored by jon-tow's avatar jon-tow
Browse files

Fix max generation limit

parent 22155f7d
...@@ -375,11 +375,10 @@ class BaseLM(LM): ...@@ -375,11 +375,10 @@ class BaseLM(LM):
).to(self.device) ).to(self.device)
if max_generation_length is None: if max_generation_length is None:
max_length = context_enc.shape[1] + self.max_gen_toks max_length = self.max_gen_toks
else: else:
max_length = min( max_length = max_generation_length
max_generation_length, context_enc.shape[1] + self.max_gen_toks
)
cont = self._model_generate( cont = self._model_generate(
context_enc, context_enc,
max_length, max_length,
...@@ -595,78 +594,6 @@ class Task(abc.ABC): ...@@ -595,78 +594,6 @@ class Task(abc.ABC):
) )
return "" return ""
@utils.positional_deprecated
def fewshot_context(
    self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
    """Build the few-shot context string for `doc`.

    The returned context is the optional task description, followed by
    `num_fewshot` labeled examples, followed by the unlabeled prompt for
    `doc` itself.

    :param doc: str
        The document as returned from training_docs, validation_docs, or test_docs.
    :param num_fewshot: int
        The number of fewshot examples to provide in the returned context string.
    :param provide_description: bool
        Deprecated; must not be truthy. Kept only so old callers fail loudly.
    :param rnd: random.Random
        The pseudo-random number generator used to randomly sample examples.
        WARNING: This is currently a required arg although it's optionalized with a default `None`.
    :param description: str
        The task's description that will be prepended to the fewshot examples.
    :returns: str
        The fewshot context.
    """
    assert (
        rnd is not None
    ), "A `random.Random` generator argument must be provided to `rnd`"
    assert not provide_description, (
        "The `provide_description` arg will be removed in future versions. To prepend "
        "a custom description to the context, supply the corresponding string via the "
        "`description` arg."
    )
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    prefix = description + "\n\n" if description else ""

    if num_fewshot == 0:
        shots = ""
    else:
        # Sample the demonstration documents, preferring the training split.
        if self.has_training_docs():
            sampled = self.fewshot_examples(k=num_fewshot, rnd=rnd)
        else:
            # No training split: lazily cache a candidate pool drawn from the
            # validation (or, failing that, test) split. Oversample by one so
            # the doc under evaluation can be dropped *and ensure no overlap
            # with the current doc*.
            if self._fewshot_docs is None:
                pool = (
                    self.validation_docs()
                    if self.has_validation_docs()
                    else self.test_docs()
                )
                self._fewshot_docs = list(pool)
            sampled = rnd.sample(self._fewshot_docs, num_fewshot + 1)
            sampled = [d for d in sampled if d != doc][:num_fewshot]

        # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
        # for justification of this separator.
        sep = "\n###\n"
        rendered = [self.doc_to_text(d) + self.doc_to_target(d) for d in sampled]
        shots = sep.join(rendered) + sep

    return prefix + shots + self.doc_to_text(doc)
class PromptSourceTask(Task): class PromptSourceTask(Task):
"""These are the metrics from promptsource that we have """These are the metrics from promptsource that we have
...@@ -691,10 +618,12 @@ class PromptSourceTask(Task): ...@@ -691,10 +618,12 @@ class PromptSourceTask(Task):
self.prompt = prompt self.prompt = prompt
self.save_examples = save_examples self.save_examples = save_examples
def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
By default, its "\n###\n". def stopping_criteria(self) -> Optional[str]:
"""
Denote where the generation should end based on the few-shot example
separator: "\n###\n".
TODO: Handle other separators in the future.
""" """
return "\n###\n" return "\n###\n"
......
...@@ -151,7 +151,7 @@ class HFLM(BaseLM): ...@@ -151,7 +151,7 @@ class HFLM(BaseLM):
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot): def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
max_length = max_length + context.size(1)
if num_fewshot == 0: if num_fewshot == 0:
generations = self.gpt2.generate( generations = self.gpt2.generate(
context, context,
......
...@@ -118,7 +118,7 @@ class GPTJLM(BaseLM): ...@@ -118,7 +118,7 @@ class GPTJLM(BaseLM):
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot): def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
max_length = max_length + context.size(1)
if num_fewshot == 0: if num_fewshot == 0:
generations = self.gptj.generate( generations = self.gptj.generate(
context, context,
......
...@@ -188,7 +188,6 @@ class T5LM(BaseLM): ...@@ -188,7 +188,6 @@ class T5LM(BaseLM):
def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot): def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids) stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
if num_fewshot == 0: if num_fewshot == 0:
generations = self.t5.generate( generations = self.t5.generate(
context, context,
......
...@@ -92,8 +92,6 @@ class DROP(PromptSourceTask): ...@@ -92,8 +92,6 @@ class DROP(PromptSourceTask):
# """ # """
# conts = [rf.greedy_until(ctx, ["."])] # conts = [rf.greedy_until(ctx, ["."])]
# return conts # return conts
# def stopping_criteria(self):
# return "."
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
......
...@@ -61,9 +61,6 @@ class E2E_NLG_Cleaned(PromptSourceTask): ...@@ -61,9 +61,6 @@ class E2E_NLG_Cleaned(PromptSourceTask):
def max_generation_length(self): def max_generation_length(self):
return 64 return 64
# def stopping_criteria(self):
# return '\n\n'
def invalid_doc_for_prompt(self, doc) -> bool: def invalid_doc_for_prompt(self, doc) -> bool:
"""The QA prompts are not applicable to all the examples, we want to filter these out.""" """The QA prompts are not applicable to all the examples, we want to filter these out."""
return self.prompt.name.endswith("_qa") or self.prompt.name == "family_friendly_yes_no" return self.prompt.name.endswith("_qa") or self.prompt.name == "family_friendly_yes_no"
...@@ -73,7 +70,7 @@ class E2E_NLG_Cleaned(PromptSourceTask): ...@@ -73,7 +70,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
text = self.prompt.apply(doc)[0] text = self.prompt.apply(doc)[0]
return text return text
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx, args):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -90,6 +87,7 @@ class E2E_NLG_Cleaned(PromptSourceTask): ...@@ -90,6 +87,7 @@ class E2E_NLG_Cleaned(PromptSourceTask):
request_args = { request_args = {
"stopping_criteria": self.stopping_criteria(), "stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(), "max_generation_length": self.max_generation_length(),
"num_fewshot": args["num_fewshot"],
} }
# Skip examples for which the templates are not applicable # Skip examples for which the templates are not applicable
......
...@@ -78,15 +78,9 @@ class AssetTurk(PromptSourceTask): ...@@ -78,15 +78,9 @@ class AssetTurk(PromptSourceTask):
def test_docs(self): def test_docs(self):
return self.dataset[str(self.SPLIT)] return self.dataset[str(self.SPLIT)]
# def stopping_criteria(self):
# return None
def max_generation_length(self): def max_generation_length(self):
return 200 return 200
# def higher_is_better(self):
# return {"bleu": True, "rouge": True}
class AssetTest(AssetTurk): class AssetTest(AssetTurk):
SPLIT = "test_asset" SPLIT = "test_asset"
......
...@@ -50,9 +50,6 @@ class GEMMLSUMEsBase(PromptSourceTask): ...@@ -50,9 +50,6 @@ class GEMMLSUMEsBase(PromptSourceTask):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["test"] return self.dataset["test"]
def stopping_criteria(self):
return "."
class GEMMLSUMEs(GEMMLSUMEsBase): class GEMMLSUMEs(GEMMLSUMEsBase):
'''this is for train/validation/test''' '''this is for train/validation/test'''
SPLIT = '' SPLIT = ''
...@@ -98,9 +95,6 @@ class GEMMLSUMDeBase(PromptSourceTask): ...@@ -98,9 +95,6 @@ class GEMMLSUMDeBase(PromptSourceTask):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["test"] return self.dataset["test"]
def stopping_criteria(self):
return "."
class GEMMLSUMDe(GEMMLSUMDeBase): class GEMMLSUMDe(GEMMLSUMDeBase):
'''this is for train/validation/test''' '''this is for train/validation/test'''
SPLIT = '' SPLIT = ''
......
...@@ -70,15 +70,9 @@ class WebNLG(PromptSourceTask): ...@@ -70,15 +70,9 @@ class WebNLG(PromptSourceTask):
else: else:
return self.dataset["test"] return self.dataset["test"]
# def stopping_criteria(self):
# return None
def max_generation_length(self): def max_generation_length(self):
return 250 return 250
# def higher_is_better(self):
# return {"bleu": True, "rouge": True}
class WebNLGRu(WebNLG): class WebNLGRu(WebNLG):
DATASET_NAME = "ru" DATASET_NAME = "ru"
......
...@@ -42,8 +42,6 @@ class GEMXSUMBase(PromptSourceTask): ...@@ -42,8 +42,6 @@ class GEMXSUMBase(PromptSourceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def stopping_criteria(self):
return '.'
def training_docs(self): def training_docs(self):
if self.has_training_docs(): if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster # We cache training documents in `self._training_docs` for faster
......
...@@ -236,9 +236,6 @@ class MRPC(PromptSourceTask): ...@@ -236,9 +236,6 @@ class MRPC(PromptSourceTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
# def stopping_criteria(self):
# return "\n###\n"
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(self.dataset["train"])
......
...@@ -54,9 +54,6 @@ class WinoBias(PromptSourceTask): ...@@ -54,9 +54,6 @@ class WinoBias(PromptSourceTask):
def test_docs(self): def test_docs(self):
return self.dataset["test"] return self.dataset["test"]
# def stopping_criteria(self):
# return "\n"
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
......
...@@ -72,11 +72,12 @@ class NewTask(PromptSourceTask): ...@@ -72,11 +72,12 @@ class NewTask(PromptSourceTask):
# named differently than the default `"test"`. # named differently than the default `"test"`.
return self.dataset["test"] return self.dataset["test"]
def stopping_criteria(self): def max_generation_length(self):
# Only define this method when you want to control few-shot generations on specific tokens. # Define this method when you want to control the length of few-shot
# The default is set to '\n###\n'. # generations on specific tokens. The default is `None` which gets mapped
# to a model's default max generation token length. E.g. see `lm_eval/models/gpt2.py:max_gen_toks()`
# NOTE: You may delete this function if the task does not required generation. # NOTE: You may delete this function if the task does not required generation.
return "\n###\n" return None
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment