Commit b2838b8d authored by cjlovering's avatar cjlovering
Browse files

Rename task specific to

parent 1dcca55c
...@@ -652,10 +652,12 @@ class PromptSourceTask(Task): ...@@ -652,10 +652,12 @@ class PromptSourceTask(Task):
super().__init__(data_dir, cache_dir, download_mode) super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt self.prompt = prompt
def end_of_generation_sequence(self): def stopping_criteria(self):
"""Denote where the generation should be split. """Denote where the generation should end.
For example, for coqa, this is '\nQ:' and for drop '.'. For example, for coqa, this is '\nQ:' and for drop '.'.
By default, its None, meaning to generate up to max or EOT, whichever comes first.
""" """
return None return None
...@@ -716,7 +718,7 @@ class PromptSourceTask(Task): ...@@ -716,7 +718,7 @@ class PromptSourceTask(Task):
_requests.append(ll_answer_choice) _requests.append(ll_answer_choice)
else: else:
# TODO(Albert): What is the stop symbol? Is it model specific? # TODO(Albert): What is the stop symbol? Is it model specific?
cont_request = rf.greedy_until(ctx, [self.end_of_generation_sequence()]) cont_request = rf.greedy_until(ctx, [self.stopping_criteria()])
_requests.append(cont_request) _requests.append(cont_request)
return _requests return _requests
......
...@@ -90,7 +90,7 @@ class CoQA(PromptSourceTask): ...@@ -90,7 +90,7 @@ class CoQA(PromptSourceTask):
"f1": f1_sum / max(1, len(gold_list)), "f1": f1_sum / max(1, len(gold_list)),
} }
def end_of_generation_sequence(self): def stopping_criteria(self):
return "\nQ:" return "\nQ:"
# def construct_requests(self, doc, ctx): # def construct_requests(self, doc, ctx):
......
...@@ -92,7 +92,7 @@ class DROP(PromptSourceTask): ...@@ -92,7 +92,7 @@ class DROP(PromptSourceTask):
# """ # """
# conts = [rf.greedy_until(ctx, ["."])] # conts = [rf.greedy_until(ctx, ["."])]
# return conts # return conts
def end_of_generation_sequence(self): def stopping_criteria(self):
return "." return "."
def process_results(self, doc, results): def process_results(self, doc, results):
......
...@@ -236,7 +236,7 @@ class MRPC(PromptSourceTask): ...@@ -236,7 +236,7 @@ class MRPC(PromptSourceTask):
def has_test_docs(self): def has_test_docs(self):
return False return False
def end_of_generation_sequence(self): def stopping_criteria(self):
return "\n" return "\n"
def training_docs(self): def training_docs(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment