Commit f39c27c2 authored by cjlovering

Rename task-specific eos_token to end_of_generation_sequence

parent c27e29e1
@@ -641,8 +641,12 @@ class PromptSourceTask(Task):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt

-    def eos_token(self):
-        raise NotImplementedError()
+    def end_of_generation_sequence(self):
+        """Denote where the generation should be split.
+
+        For example, for coqa, this is '\nQ:' and for drop '.'.
+        """
+        return None

     def is_generation_task(self):
         return (
@@ -650,6 +654,29 @@
             or "ROUGE" in self.prompt.metadata.metrics
         )

+    def invalid_doc_for_prompt(self, doc):
+        """Some prompts may not work for some documents.
+
+        As of now, we skip particular prompts, s.t. we don't
+        overskip. If this turns out to be a problem for many prompts
+        we can instead make sure that apply returns 2 things.
+        """
+        if (
+            # generate_paraphrase for mrpc
+            (
+                self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
+                or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f"
+            )
+            and doc["label"] == 0
+        ):
+            # This generation prompt assumes a positive example. We filter out the negative examples.
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
+            return True
+        return False
+
     def doc_to_target(self, doc):
         _, target = self.prompt.apply(doc)
         return f" {target}"
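Aside: a self-contained sketch of the same filtering idea, using a hypothetical skip_doc helper and made-up documents (not part of this commit); the prompt ids are the two mrpc paraphrase-generation templates referenced in the comment above.

# Hypothetical standalone version of the filter, for illustration only.
PARAPHRASE_PROMPT_IDS = {
    "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5",
    "d830d7a5-abc0-4275-ac62-974e0088876f",
}

def skip_doc(prompt_id, doc):
    # A paraphrase-generation prompt only makes sense for a positive pair,
    # so negative pairs (label == 0) are skipped.
    return prompt_id in PARAPHRASE_PROMPT_IDS and doc["label"] == 0

docs = [
    {"sentence1": "a", "sentence2": "b", "label": 0},   # not paraphrases
    {"sentence1": "a", "sentence2": "a'", "label": 1},  # paraphrases
]
kept = [d for d in docs if not skip_doc("3b88d2c4-0aeb-4c6d-9ccc-653a388250a5", d)]
# kept contains only the positive pair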
@@ -684,7 +711,7 @@ class PromptSourceTask(Task):
                 _requests.append(ll_answer_choice)
         else:
             # TODO(Albert): What is the stop symbol? Is it model specific?
-            cont_request = rf.greedy_until(ctx, [self.eos_token()])
+            cont_request = rf.greedy_until(ctx, [self.end_of_generation_sequence()])
             _requests.append(cont_request)

         return _requests
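Aside: a minimal sketch (assumed behaviour, not taken from this commit) of how a stop sequence returned by end_of_generation_sequence() is typically used once the greedy generation comes back: everything from the first occurrence of the stop string onward is discarded.

# Illustrative helper only; the real trimming happens inside the harness.
def trim_generation(generation, stop=None):
    if stop is None:
        return generation
    return generation.split(stop)[0]

# With CoQA's "\nQ:" stop sequence, only the answer to the current question is
# kept and the model's attempt at starting the next "Q:" turn is dropped:
raw = " In a barn near the farm house.\nQ: Did she live alone?"
print(trim_generation(raw, "\nQ:"))  # -> " In a barn near the farm house."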
@@ -2,6 +2,7 @@ import collections
 import itertools
 import pathlib
 import random

+import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
@@ -199,6 +200,9 @@ def evaluate(
         )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+            if task.invalid_doc_for_prompt(doc):
+                continue
+
             docs[(task_prompt_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
@@ -118,7 +118,10 @@ class HFLM(BaseLM):
     def _model_generate(self, context, max_length, eos_token_id):
         return self.gpt2.generate(
-            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
+            context,
+            max_length=max_length,
+            eos_token_id=eos_token_id,
+            do_sample=False,
         )
@@ -90,7 +90,7 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }

-    def eos_token(self):
+    def end_of_generation_sequence(self):
         return "\nQ:"

     # def construct_requests(self, doc, ctx):
@@ -92,7 +92,7 @@ class DROP(PromptSourceTask):
    #     """
    #     conts = [rf.greedy_until(ctx, ["."])]
    #     return conts

-    def eos_token(self):
+    def end_of_generation_sequence(self):
         return "."

     def process_results(self, doc, results):
@@ -236,6 +236,9 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False

+    def end_of_generation_sequence(self):
+        return "\n"
+
     def training_docs(self):
         if self._training_docs is None:
             self._training_docs = list(self.dataset["train"])