Commit 1dcca55c authored by cjlovering

Minor updates to documentation.

parent d4c00093
@@ -665,34 +665,28 @@ class PromptSourceTask(Task):
             or "ROUGE" in self.prompt.metadata.metrics
         )
 
-    def invalid_doc_for_prompt(self, doc):
-        """Some prompts may not work for some documents.
-        As of now, we skip particular prompts, s.t. we don't
-        overskip. If this turns out to be a problem for many prompts
-        we can instead make sure that apply returns 2 things.
-        """
+    def invalid_doc_for_prompt(self, doc) -> bool:
+        """Some prompts may not work for some documents."""
         if (
+            # generate_paraphrase for mrpc
+            # This generation prompt assumes a positive example. We filter out the negative examples.
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
             (
                 self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
                 or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f"
             )
             and doc["label"] == 0
         ):
-            # This generation prompt assumes a positive example. We filter out the negative examples.
-            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
-            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
             return True
         return False
 
-    def doc_to_target(self, doc):
+    def doc_to_target(self, doc) -> str:
+        """NOTE: In the future, this may return Union[str, List[str]]."""
         _, target = self.prompt.apply(doc)
         return f" {target}"
 
-    def doc_to_text(self, doc):
+    def doc_to_text(self, doc) -> str:
         text, _ = self.prompt.apply(doc)
         return text
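
For context on the methods touched here: `prompt.apply(doc)` returns a `(text, target)` pair, and `invalid_doc_for_prompt` lets the evaluation loop skip documents a given template cannot render (e.g. mrpc's generate_paraphrase on negative pairs). A minimal usage sketch, assuming a hypothetical `task` instance and `docs` iterable that are not part of this diff:

def build_eval_pairs(task, docs):
    # Collect (input text, expected target) pairs, skipping documents
    # that the currently selected prompt cannot render sensibly.
    pairs = []
    for doc in docs:
        if task.invalid_doc_for_prompt(doc):
            continue
        pairs.append((task.doc_to_text(doc), task.doc_to_target(doc)))
    return pairs

Note the leading space that doc_to_target prepends to the target; callers comparing strings directly may want to strip it, as process_results does below.
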
@@ -737,9 +731,6 @@ class PromptSourceTask(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # raise NotImplementedError(
-        #     "Implement process results using the `prompt.metadata.metrics`. See below."
-        # )
         target = self.doc_to_target(doc).strip()
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         if answer_choices_list:
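
The body following `if answer_choices_list:` is elided from this diff. When a template defines answer choices, the standard pattern is rank classification: each entry in `results` scores one choice and the prediction is the argmax. A hedged sketch of that pattern only, not the file's actual implementation:

import numpy as np

def rank_classification_acc(results, answer_choices_list, target):
    # Assumes `results` holds one loglikelihood per answer choice,
    # aligned with `answer_choices_list`, and `target` is the gold choice.
    pred = answer_choices_list[int(np.argmax(results))]
    return {"acc": float(pred == target)}
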
@@ -772,9 +763,6 @@ class PromptSourceTask(Task):
             print("WARNING: Skipping Rouge.")
         return out
 
-        # Map metric name to HF metric.
-        # TODO(Albert): What is Other?
-        # metric_names = prompt.metadata.metrics
 
     def higher_is_better(self):
         out = {}
......
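
The deleted comments gestured at mapping promptsource metric names onto Hugging Face metrics. A sketch of what such a mapping could look like with the `evaluate` library; the table below is an assumption for illustration, not code from this repository:

import evaluate

# Assumed correspondence between prompt.metadata.metrics entries and
# HF `evaluate` identifiers; "Other" (per the deleted TODO) stays unmapped.
METRIC_NAME_TO_HF = {
    "Accuracy": "accuracy",
    "BLEU": "bleu",
    "ROUGE": "rouge",
}

def load_hf_metric(name: str):
    # Raises KeyError for unmapped names such as "Other".
    return evaluate.load(METRIC_NAME_TO_HF[name])
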