Commit f39c27c2 authored by cjlovering

Rename task specific to

parent c27e29e1
@@ -641,8 +641,12 @@ class PromptSourceTask(Task):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
 
-    def eos_token(self):
-        raise NotImplementedError()
+    def end_of_generation_sequence(self):
+        """Denote where the generation should be split.
+
+        For example, for coqa, this is '\nQ:' and for drop '.'.
+        """
+        return None
 
     def is_generation_task(self):
         return (
@@ -650,6 +654,29 @@ class PromptSourceTask(Task):
             or "ROUGE" in self.prompt.metadata.metrics
         )
 
+    def invalid_doc_for_prompt(self, doc):
+        """Some prompts may not work for some documents.
+
+        As of now, we skip particular prompts, s.t. we don't
+        over-skip. If this turns out to be a problem for many prompts,
+        we can instead make sure that `apply` returns 2 things.
+        """
+        if (
+            # generate_paraphrase for mrpc
+            (
+                self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
+                or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f"
+            )
+            and doc["label"] == 0
+        ):
+            # These generation prompts assume a positive example,
+            # so we filter out the negative examples.
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
+            return True
+        return False
+
     def doc_to_target(self, doc):
         _, target = self.prompt.apply(doc)
         return f" {target}"
@@ -684,7 +711,7 @@ class PromptSourceTask(Task):
                 _requests.append(ll_answer_choice)
         else:
             # TODO(Albert): What is the stop symbol? Is it model specific?
-            cont_request = rf.greedy_until(ctx, [self.eos_token()])
+            cont_request = rf.greedy_until(ctx, [self.end_of_generation_sequence()])
             _requests.append(cont_request)
 
         return _requests
...
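The docstring above says the task-specific sequence marks where a generation should be split. A minimal, self-contained sketch of that splitting step, assuming the model's raw continuation is post-processed by the caller; the helper name `truncate_generation` is illustrative and not part of this commit:

```python
from typing import Optional


def truncate_generation(generation: str, stop_sequence: Optional[str]) -> str:
    """Keep only the text before the task-specific stop sequence, if one is defined."""
    if stop_sequence is None:
        return generation
    return generation.split(stop_sequence)[0]


# For coqa the sequence is "\nQ:", so everything from the next question onward is dropped;
# for drop it is ".", so only the first sentence of the generation survives.
raw = "Paris\nQ: What about Germany?\nA: Berlin"
print(truncate_generation(raw, "\nQ:"))  # -> "Paris"
```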
@@ -2,6 +2,7 @@ import collections
 import itertools
 import pathlib
 import random
+
 import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
@@ -199,6 +200,9 @@ def evaluate(
         )
 
         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+            if task.invalid_doc_for_prompt(doc):
+                continue
+
             docs[(task_prompt_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
...
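One consequence of skipping inside the loop: the `islice` is taken before the validity check, so skipped documents still count toward `limit`. A small stand-alone illustration of that behavior; the `invalid` predicate is a stand-in for `task.invalid_doc_for_prompt`:

```python
import itertools

docs = [{"label": 0}, {"label": 1}, {"label": 0}, {"label": 1}]


def invalid(doc):
    # Stand-in for task.invalid_doc_for_prompt: skip negative examples.
    return doc["label"] == 0


limit = 3
kept = [doc for doc in itertools.islice(docs, 0, limit) if not invalid(doc)]
print(len(kept))  # 1, not 3: invalid docs are sliced in first and then dropped
```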
@@ -118,7 +118,10 @@ class HFLM(BaseLM):
     def _model_generate(self, context, max_length, eos_token_id):
         return self.gpt2.generate(
-            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
+            context,
+            max_length=max_length,
+            eos_token_id=eos_token_id,
+            do_sample=False,
         )
...
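For reference, `do_sample=False` makes `generate` decode greedily and `eos_token_id` acts as the stop criterion. A rough sketch of the same call against a plain Hugging Face causal LM; the model choice and prompt here are illustrative, not taken from the harness:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

context = tokenizer("Q: What is the capital of France?\nA:", return_tensors="pt").input_ids
output = model.generate(
    context,
    max_length=context.shape[1] + 32,    # cap total length, as _model_generate does via max_length
    eos_token_id=tokenizer.eos_token_id, # stop when the EOS token is produced
    do_sample=False,                     # greedy decoding
)
# Decode only the newly generated continuation.
print(tokenizer.decode(output[0, context.shape[1]:]))
```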
@@ -90,7 +90,7 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }
 
-    def eos_token(self):
+    def end_of_generation_sequence(self):
         return "\nQ:"
 
     # def construct_requests(self, doc, ctx):
...
@@ -92,7 +92,7 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts
 
-    def eos_token(self):
+    def end_of_generation_sequence(self):
         return "."
 
     def process_results(self, doc, results):
...
@@ -236,6 +236,9 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False
 
+    def end_of_generation_sequence(self):
+        return "\n"
+
     def training_docs(self):
         if self._training_docs is None:
             self._training_docs = list(self.dataset["train"])
...