Commit b62d1bec authored by Tian Yun's avatar Tian Yun
Browse files

Update stopping criteria for few-shot

parent 5e59320b
...@@ -694,11 +694,9 @@ class PromptSourceTask(Task): ...@@ -694,11 +694,9 @@ class PromptSourceTask(Task):
def stopping_criteria(self) -> Optional[str]: def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end. """Denote where the generation should end.
For example, for coqa, this is '\nQ:' and for drop '.'. By default, its "\n###\n".
By default, its None, meaning to generate up to max or EOT, whichever comes first.
""" """
return None return "\n###\n"
def max_generation_length(self) -> Optional[int]: def max_generation_length(self) -> Optional[int]:
"""Denote where the max length of the generation if it is obvious from the task.""" """Denote where the max length of the generation if it is obvious from the task."""
......
...@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask): ...@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
"f1": f1_sum / max(1, len(gold_list)), "f1": f1_sum / max(1, len(gold_list)),
} }
def stopping_criteria(self): # def stopping_criteria(self):
return "\n\n" # return "\n\n"
# def construct_requests(self, doc, ctx): # def construct_requests(self, doc, ctx):
# """Uses RequestFactory to construct Requests and returns an iterable of # """Uses RequestFactory to construct Requests and returns an iterable of
......
...@@ -92,8 +92,8 @@ class DROP(PromptSourceTask): ...@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
# """ # """
# conts = [rf.greedy_until(ctx, ["."])] # conts = [rf.greedy_until(ctx, ["."])]
# return conts # return conts
def stopping_criteria(self): # def stopping_criteria(self):
return "." # return "."
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
......
...@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask): ...@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
def test_docs(self): def test_docs(self):
return self.dataset[str(self.SPLIT)] return self.dataset[str(self.SPLIT)]
def stopping_criteria(self): # def stopping_criteria(self):
return None # return None
def max_generation_length(self): def max_generation_length(self):
return 200 return 200
......
...@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask): ...@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
else: else:
return self.dataset["test"] return self.dataset["test"]
def stopping_criteria(self): # def stopping_criteria(self):
return None # return None
def max_generation_length(self): def max_generation_length(self):
return 250 return 250
......
...@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask): ...@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def stopping_criteria(self): # def stopping_criteria(self):
return "\n" # return "\n###\n"
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
......
...@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask): ...@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
def test_docs(self): def test_docs(self):
return self.dataset["test"] return self.dataset["test"]
def stopping_criteria(self): # def stopping_criteria(self):
return "\n" # return "\n"
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
......
...@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask): ...@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
return self.dataset["test"] return self.dataset["test"]
def stopping_criteria(self): def stopping_criteria(self):
# TODO: Denote the string where the generation should be split. # Only define this method when you want to control few-shot generations on specific tokens.
# For example, for `coqa`, this is '\nQ:' and for `drop` '.'. # The default is set to '\n###\n'.
# NOTE: You may delete this function if the task does not required generation. # NOTE: You may delete this function if the task does not required generation.
return None return "\n###\n"
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment