Commit 0b8cb8b9 authored by Tian Yun

Merge with master branch

parents 96ea7ddc b2838b8d
@@ -14,6 +14,7 @@ from tqdm import tqdm
 import torch
 import torch.nn.functional as F
+from lm_eval import metrics
 from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
 from lm_eval import utils
 from abc import abstractmethod
@@ -637,12 +638,28 @@ class Task(abc.ABC):
 class PromptSourceTask(Task):
+    """These are the metrics from promptsource that we have
+    added default behavior for. If you want to add default behavior for a new metric,
+    update the functions below. If you want to use one of the following metrics
+    *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
+
+    WARNING: ROUGE is WIP.
+    """
+
+    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
+
     def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
 
-    def eos_token(self):
-        raise NotImplementedError()
+    def stopping_criteria(self):
+        """Denote where the generation should end.
+
+        For example, for coqa this is '\nQ:' and for drop it is '.'.
+        By default it is None, meaning generation runs to max length or EOT, whichever comes first.
+        """
+        return None
 
     def is_generation_task(self):
         return (
@@ -650,11 +667,28 @@ class PromptSourceTask(Task):
             or "ROUGE" in self.prompt.metadata.metrics
         )
 
-    def doc_to_target(self, doc):
+    def invalid_doc_for_prompt(self, doc) -> bool:
+        """Some prompts may not work for some documents."""
+        if (
+            # generate_paraphrase for mrpc
+            # This generation prompt assumes a positive example. We filter out the negative examples.
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
+            # https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
+            (
+                self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
+                or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f"
+            )
+            and doc["label"] == 0
+        ):
+            return True
+        return False
+
+    def doc_to_target(self, doc) -> str:
+        """NOTE: In the future, this may return Union[str, List[str]]."""
         _, target = self.prompt.apply(doc)
         return f" {target}"
 
-    def doc_to_text(self, doc):
+    def doc_to_text(self, doc) -> str:
         text, _ = self.prompt.apply(doc)
         return text
@@ -684,7 +718,7 @@ class PromptSourceTask(Task):
                 _requests.append(ll_answer_choice)
         else:
             # TODO(Albert): What is the stop symbol? Is it model specific?
-            cont_request = rf.greedy_until(ctx, [self.eos_token()])
+            cont_request = rf.greedy_until(ctx, [self.stopping_criteria()])
             _requests.append(cont_request)
 
         return _requests
@@ -699,9 +733,6 @@ class PromptSourceTask(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # raise NotImplementedError(
-        #     "Implement process results using the `prompt.metadata.metrics`. See below."
-        # )
         target = self.doc_to_target(doc).strip()
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         if answer_choices_list:
@@ -710,29 +741,57 @@ class PromptSourceTask(Task):
             ), f"We expect this to be a ranked choice task; double check please."
             pred = answer_choices_list[np.argmax(results)]
             out = {}
-            if "Accuracy" in self.prompt.metadata.metrics:
+            for metric in self.prompt.metadata.metrics:
+                assert (
+                    metric in self.CONFIGURED_PS_METRICS
+                ), "Unexpected metric. Add it, or use a task-specific solution."
+                if metric == "Accuracy":
                     out["acc"] = pred == target
                 # TODO: Add metrics here.
             return out
         else:
-            raise NotImplementedError("Generation is not implemented yet.")
-            # Map metric name to HF metric.
-            # TODO(Albert): What is Other?
-            # metric_names = prompt.metadata.metrics
+            # NOTE: In the future, target may be a list, not a string.
+            pred = results[0].strip()
+            out = {}
+            for metric in self.prompt.metadata.metrics:
+                assert (
+                    metric in self.CONFIGURED_PS_METRICS
+                ), "Unexpected metric. Add it, or use a task-specific solution."
+                if metric == "BLEU":
+                    out["bleu"] = (target, pred)
+                if metric == "ROUGE":
+                    print("WARNING: Skipping Rouge.")
+            return out
 
     def higher_is_better(self):
         out = {}
-        if "Accuracy" in self.prompt.metadata.metrics:
+        for metric in self.prompt.metadata.metrics:
+            assert (
+                metric in self.CONFIGURED_PS_METRICS
+            ), "Unexpected metric. Add it, or use a task-specific solution."
+            if metric == "Accuracy":
                 out["acc"] = True
+            if metric == "BLEU":
+                out["bleu"] = True
+            if metric == "ROUGE":
+                print("WARNING: Skipping Rouge.")
         return out
 
     def aggregation(self):
         out = {}
-        if "Accuracy" in self.prompt.metadata.metrics:
+        for metric in self.prompt.metadata.metrics:
+            assert (
+                metric in self.CONFIGURED_PS_METRICS
+            ), "Unexpected metric. Add it, or use a task-specific solution."
+            if metric == "Accuracy":
                 out["acc"] = mean
+            if metric == "BLEU":
+                out["bleu"] = metrics.bleu
+            if metric == "ROUGE":
+                print("WARNING: Skipping Rouge.")
         return out
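As a usage note for the class docstring above: a downstream generation task that leans on these defaults only needs to pick a stop string, roughly as in the hedged sketch below. The class name, stop string, and the assumption about the prompt's metrics are illustrative only, and the import path is assumed rather than shown in this commit.

from lm_eval.base import PromptSourceTask  # assumed import path for the class above


class HypotheticalSummarizationTask(PromptSourceTask):
    """Illustrative only; not part of this commit."""

    def stopping_criteria(self):
        # Stop generation at the first blank line rather than at max length / EOT.
        return "\n\n"

    # No process_results / higher_is_better / aggregation overrides: assuming the
    # prompt's metadata.metrics lists "BLEU" (already in CONFIGURED_PS_METRICS),
    # the defaults above collect (target, pred) pairs and aggregate with metrics.bleu.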
@@ -2,6 +2,7 @@ import collections
 import itertools
 import pathlib
 import random
 import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
@@ -199,6 +200,9 @@ def evaluate(
         )
 
         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+            if task.invalid_doc_for_prompt(doc):
+                continue
+
             docs[(task_prompt_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
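Because the new `invalid_doc_for_prompt` check sits inside `enumerate`, a skipped document still consumes a `doc_id`, so ids stay aligned with the underlying dataset order. A small illustration of what gets filtered, assuming `task` is an MRPC task built with one of the two paraphrase-generation prompt ids listed earlier; the field values are made up.

# Hypothetical MRPC-style examples; only the label matters for the filter.
positive = {"sentence1": "He said hi.", "sentence2": "He greeted them.", "label": 1}
negative = {"sentence1": "He said hi.", "sentence2": "She left early.", "label": 0}

assert not task.invalid_doc_for_prompt(positive)  # kept: requests are built as usual
assert task.invalid_doc_for_prompt(negative)      # skipped by the loop above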
@@ -90,7 +90,7 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }
 
-    def eos_token(self):
+    def stopping_criteria(self):
         return "\nQ:"
 
     # def construct_requests(self, doc, ctx):
@@ -92,7 +92,7 @@ class DROP(PromptSourceTask):
     #     """
     #     conts = [rf.greedy_until(ctx, ["."])]
     #     return conts
 
-    def eos_token(self):
+    def stopping_criteria(self):
         return "."
 
     def process_results(self, doc, results):
@@ -236,6 +236,9 @@ class MRPC(PromptSourceTask):
     def has_test_docs(self):
         return False
 
+    def stopping_criteria(self):
+        return "\n"
+
     def training_docs(self):
         if self._training_docs is None:
             self._training_docs = list(self.dataset["train"])
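For reference, these per-task stop strings flow into the generation request built in `construct_requests` (see the base-class change earlier). A rough sketch of that path follows; the prompt object, document, and few-shot settings are placeholders, not values from this commit.

# Sketch only: `some_generation_prompt` and `doc` are placeholders.
task = MRPC(prompt=some_generation_prompt)
ctx = task.fewshot_context(
    doc=doc, num_fewshot=0, rnd=random.Random(1234), description=None
)
# greedy_until receives the task-specific stop string: "\n" for MRPC,
# "\nQ:" for CoQA, "." for DROP, or None (run to max length / EOT) by default.
request = rf.greedy_until(ctx, [task.stopping_criteria()])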