Commit 4f85bcf9 authored by cjlovering

Updated doc: If the answer choices list is empty, then it is generation; else, ranked choice. This will be the canonical approach when using PS (PromptSource).
parent af3cccc8
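
The convention this commit codifies reduces to a single dispatch on the answer choices list. Below is a minimal illustrative sketch of that rule, not the actual class: it reuses the helper names from the diff that follows (rf.loglikelihood, rf.greedy_until, get_answer_choices_list) and elides everything else, so it is not runnable outside the harness.

    # Sketch: canonical PromptSource (PS) dispatch rule.
    # Assumes the rf request helpers behave as in the diff below.
    def construct_requests(self, doc, ctx):
        answer_choices_list = self.prompt.get_answer_choices_list(doc)
        if answer_choices_list:
            # Non-empty answer choices => ranked choice:
            # request a loglikelihood score for each candidate.
            return [
                rf.loglikelihood(ctx, f" {choice}")[0]
                for choice in answer_choices_list
            ]
        # Empty answer choices => generation:
        # request greedy decoding until a stopping condition.
        return [
            rf.greedy_until(
                ctx, [self.stopping_criteria(), self.max_generation_length()]
            )
        ]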
@@ -673,12 +673,6 @@ class PromptSourceTask(Task):
         """Denote the max length of the generation if it is obvious from the task."""
         return None

-    def is_generation_task(self):
-        return (
-            "BLEU" in self.prompt.metadata.metrics
-            or "ROUGE" in self.prompt.metadata.metrics
-        )
-
     def invalid_doc_for_prompt(self, doc) -> bool:
         """Some prompts may not work for some documents."""
         if (
@@ -718,18 +712,14 @@ class PromptSourceTask(Task):
         _requests = []
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
-        # We take a present answer_choices list to mean that we should apply the supplied
-        # metrics (hardcoded or accuracy atm) to the ranked choices. Otherwise, assume generation.
-        # Above we do something similar, but rely on the metrics requested (BLEU, ROUGE indicating generation).
         if answer_choices_list:
-            assert (
-                not self.is_generation_task()
-            ), f"We expect this to be a ranked choice task; double check please."
+            # If answer_choices_list, then this is a ranked choice prompt.
             for answer_choice in answer_choices_list:
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
-            # TODO(Albert): What is the stop symbol? Is it model specific?
+            # If not, then this is a generation prompt.
+            # NOTE: In the future, target will be a list of strings.
             cont_request = rf.greedy_until(
                 ctx, [self.stopping_criteria(), self.max_generation_length()]
             )
@@ -750,9 +740,11 @@ class PromptSourceTask(Task):
         target = self.doc_to_target(doc).strip()
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         if answer_choices_list:
-            assert (
-                not self.is_generation_task()
-            ), f"We expect this to be a ranked choice task; double check please."
+            # If answer_choices_list, then this is a ranked choice prompt.
+            # NOTE: In the future, target will be a list of strings.
+            # For now, we can assume there will be only 1 target, but it's possible
+            # that this is not the case, so we should check for that.
             pred = answer_choices_list[np.argmax(results)]
             out = {}
@@ -765,7 +757,8 @@ class PromptSourceTask(Task):
             # TODO: Add metrics here.
             return out
         else:
-            # NOTE: In the future, target may be a list, not a string.
+            # If not, then this is a generation prompt.
+            # NOTE: In the future, target will be a list of strings.
             pred = results[0].strip()
             out = {}
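
To make the ranked-choice branch concrete, here is a toy worked example of the np.argmax selection in process_results; the loglikelihood values are invented for illustration:

    import numpy as np

    answer_choices_list = ["yes", "no", "maybe"]
    # Hypothetical loglikelihoods returned for the three loglikelihood requests.
    results = [-2.3, -0.7, -5.1]
    pred = answer_choices_list[np.argmax(results)]
    print(pred)  # "no" -- index 1 has the highest loglikelihood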