Commit b2838b8d authored by cjlovering

Rename task-specific end_of_generation_sequence to stopping_criteria

parent 1dcca55c
@@ -652,10 +652,12 @@ class PromptSourceTask(Task):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
 
-    def end_of_generation_sequence(self):
-        """Denote where the generation should be split.
+    def stopping_criteria(self):
+        """Denote where the generation should end.
 
         For example, for coqa, this is '\nQ:' and for drop '.'.
+
+        By default, its None, meaning to generate up to max or EOT, whichever comes first.
         """
         return None
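
Note: the stopping_criteria hook above is what task subclasses override to choose their stop string. A minimal sketch of such an override (the GenericQATask class and the "\nQuestion:" stop string are hypothetical, for illustration only; only the method name comes from this commit):

class GenericQATask(PromptSourceTask):
    def stopping_criteria(self):
        # Cut generation off when the model starts emitting the next question;
        # returning None instead means generate up to max length or EOT.
        return "\nQuestion:"
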
@@ -716,7 +718,7 @@ class PromptSourceTask(Task):
                 _requests.append(ll_answer_choice)
         else:
             # TODO(Albert): What is the stop symbol? Is it model specific?
-            cont_request = rf.greedy_until(ctx, [self.end_of_generation_sequence()])
+            cont_request = rf.greedy_until(ctx, [self.stopping_criteria()])
             _requests.append(cont_request)
 
         return _requests
@@ -90,7 +90,7 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }
 
-    def end_of_generation_sequence(self):
+    def stopping_criteria(self):
         return "\nQ:"
 
     # def construct_requests(self, doc, ctx):
@@ -92,7 +92,7 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts
 
-    def end_of_generation_sequence(self):
+    def stopping_criteria(self):
         return "."
 
     def process_results(self, doc, results):
@@ -236,7 +236,7 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False
 
-    def end_of_generation_sequence(self):
+    def stopping_criteria(self):
         return "\n"
 
     def training_docs(self):