Commit b62d1bec authored by Tian Yun's avatar Tian Yun
Browse files

Update stopping criteria for few-shot

parent 5e59320b
......@@ -694,11 +694,9 @@ class PromptSourceTask(Task):
def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
For example, for coqa, this is '\nQ:' and for drop '.'.
By default, its None, meaning to generate up to max or EOT, whichever comes first.
By default, its "\n###\n".
"""
return None
return "\n###\n"
def max_generation_length(self) -> Optional[int]:
"""Denote where the max length of the generation if it is obvious from the task."""
......
......@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
"f1": f1_sum / max(1, len(gold_list)),
}
def stopping_criteria(self):
return "\n\n"
# def stopping_criteria(self):
# return "\n\n"
# def construct_requests(self, doc, ctx):
# """Uses RequestFactory to construct Requests and returns an iterable of
......
......@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
# """
# conts = [rf.greedy_until(ctx, ["."])]
# return conts
def stopping_criteria(self):
return "."
# def stopping_criteria(self):
# return "."
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
......@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
def test_docs(self):
return self.dataset[str(self.SPLIT)]
def stopping_criteria(self):
return None
# def stopping_criteria(self):
# return None
def max_generation_length(self):
return 200
......
......@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
else:
return self.dataset["test"]
def stopping_criteria(self):
return None
# def stopping_criteria(self):
# return None
def max_generation_length(self):
return 250
......
......@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
def has_test_docs(self):
return False
def stopping_criteria(self):
return "\n"
# def stopping_criteria(self):
# return "\n###\n"
def training_docs(self):
if self._training_docs is None:
......
......@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
def test_docs(self):
return self.dataset["test"]
def stopping_criteria(self):
return "\n"
# def stopping_criteria(self):
# return "\n"
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
......@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
return self.dataset["test"]
def stopping_criteria(self):
# TODO: Denote the string where the generation should be split.
# For example, for `coqa`, this is '\nQ:' and for `drop` '.'.
# Only define this method when you want to control few-shot generations on specific tokens.
# The default is set to '\n###\n'.
# NOTE: You may delete this function if the task does not required generation.
return None
return "\n###\n"
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment