"script/cmake-rocm3.7.sh" did not exist on "bbcb67d0aac81b51336981713662a726875ebd58"
Commit 4f85bcf9 authored by cjlovering's avatar cjlovering
Browse files

Updated doc: If the answer choices is empty, then it is generation; else...

Updated doc: If the answer choices is empty, then it is generation; else ranked choice. This will be the canonical approach when using PS.
parent af3cccc8
......@@ -673,12 +673,6 @@ class PromptSourceTask(Task):
"""Denote where the max length of the generation if it is obvious from the task."""
return None
def is_generation_task(self):
return (
"BLEU" in self.prompt.metadata.metrics
or "ROUGE" in self.prompt.metadata.metrics
)
def invalid_doc_for_prompt(self, doc) -> bool:
"""Some prompts may not work for some documents."""
if (
......@@ -718,18 +712,14 @@ class PromptSourceTask(Task):
_requests = []
answer_choices_list = self.prompt.get_answer_choices_list(doc)
# We take a present answer_choices list to mean that we should apply the supplied
# metrics (hardcoded or accuracy atm) to the ranked choices. Otherwise, assume generation.
# Above we do something similar, but rely on the metrics requested (BLEU, ROUGE indicating generation).
if answer_choices_list:
assert (
not self.is_generation_task()
), f"We expect this to be a ranked choice task; double check please."
# If answer_choices_list, then this is a ranked choice prompt.
for answer_choice in answer_choices_list:
ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
_requests.append(ll_answer_choice)
else:
# TODO(Albert): What is the stop symbol? Is it model specific?
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
cont_request = rf.greedy_until(
ctx, [self.stopping_criteria(), self.max_generation_length()]
)
......@@ -750,9 +740,11 @@ class PromptSourceTask(Task):
target = self.doc_to_target(doc).strip()
answer_choices_list = self.prompt.get_answer_choices_list(doc)
if answer_choices_list:
assert (
not self.is_generation_task()
), f"We expect this to be a ranked choice task; double check please."
# If answer_choices_list, then this is a ranked choice prompt.
# NOTE: In the future, target will be a list of strings.
# For now, we can assume there will be only 1 target, but its possible
# that this not the case so we should check for that.
pred = answer_choices_list[np.argmax(results)]
out = {}
......@@ -765,7 +757,8 @@ class PromptSourceTask(Task):
# TODO: Add metrics here.
return out
else:
# NOTE: In the future, target may be a list, not a string.
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
pred = results[0].strip()
out = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment