Commit d4c00093 authored by cjlovering

Added default behavior for BLEU to the PromptSourceTask class

parent f39c27c2
@@ -14,6 +14,7 @@ from tqdm import tqdm
 import torch
 import torch.nn.functional as F
+from lm_eval import metrics
 from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
 from lm_eval import utils
 from abc import abstractmethod
@@ -637,6 +638,16 @@ class Task(abc.ABC):
 class PromptSourceTask(Task):
+    """These are the metrics from promptsource that we have
+    added default behavior for. If you want to add default behavior for a new metric,
+    update the functions below. If you want to use one of the following metrics,
+    *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
+
+    WARNING: ROUGE is WIP.
+    """
+
+    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
+
     def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
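As the docstring says, tasks that want one of the configured metrics plus extra custom processing should override `process_results`, `higher_is_better`, and `aggregation`. The sketch below is a hypothetical subclass, not part of this commit: the class name and the `exact_match` key are illustrative, and it assumes the standard lm_eval Task hooks, i.e. a `process_results(self, doc, results)` signature, `doc_to_target(doc)`, and a generation-style prompt where `results[0]` is the model's text output.

from lm_eval.metrics import mean

class MyPromptedTask(PromptSourceTask):
    """Hypothetical subclass: keeps PromptSourceTask's default per-metric
    handling and layers one extra, task-specific key on top."""

    def process_results(self, doc, results):
        # Reuse the default behavior added in this commit.
        out = super().process_results(doc, results)
        # Illustrative extra metric; the key name and logic are assumptions.
        target = self.doc_to_target(doc).strip()
        out["exact_match"] = float(results[0].strip() == target)
        return out

    def higher_is_better(self):
        out = super().higher_is_better()
        out["exact_match"] = True
        return out

    def aggregation(self):
        out = super().aggregation()
        out["exact_match"] = mean
        return out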
@@ -737,29 +748,60 @@ class PromptSourceTask(Task):
             ), f"We expect this to be a ranked choice task; double check please."
             pred = answer_choices_list[np.argmax(results)]
             out = {}
-            if "Accuracy" in self.prompt.metadata.metrics:
-                out["acc"] = pred == target
+            for metric in self.prompt.metadata.metrics:
+                assert (
+                    metric in self.CONFIGURED_PS_METRICS
+                ), "Unexpected metric. Add it, or use a task-specific solution."
+                if metric == "Accuracy":
+                    out["acc"] = pred == target
             # TODO: Add metrics here.
             return out
         else:
-            raise NotImplementedError("Generation is not implemented yet.")
+            # NOTE: In the future, target may be a list, not a string.
+            pred = results[0].strip()
+            out = {}
+            for metric in self.prompt.metadata.metrics:
+                assert (
+                    metric in self.CONFIGURED_PS_METRICS
+                ), "Unexpected metric. Add it, or use a task-specific solution."
+                if metric == "BLEU":
+                    out["bleu"] = (target, pred)
+                if metric == "ROUGE":
+                    print("WARNING: Skipping Rouge.")
+            return out

         # Map metric name to HF metric.
         # TODO(Albert): What is Other?
         # metric_names = prompt.metadata.metrics

     def higher_is_better(self):
         out = {}
-        if "Accuracy" in self.prompt.metadata.metrics:
-            out["acc"] = True
+        for metric in self.prompt.metadata.metrics:
+            assert (
+                metric in self.CONFIGURED_PS_METRICS
+            ), "Unexpected metric. Add it, or use a task-specific solution."
+            if metric == "Accuracy":
+                out["acc"] = True
+            if metric == "BLEU":
+                out["bleu"] = True
+            if metric == "ROUGE":
+                print("WARNING: Skipping Rouge.")
         return out

     def aggregation(self):
         out = {}
-        if "Accuracy" in self.prompt.metadata.metrics:
-            out["acc"] = mean
+        for metric in self.prompt.metadata.metrics:
+            assert (
+                metric in self.CONFIGURED_PS_METRICS
+            ), "Unexpected metric. Add it, or use a task-specific solution."
+            if metric == "Accuracy":
+                out["acc"] = mean
+            if metric == "BLEU":
+                out["bleu"] = metrics.bleu
+            if metric == "ROUGE":
+                print("WARNING: Skipping Rouge.")
         return out
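Taken together, the three hooks define a per-example to corpus-level contract: `process_results` emits one dict per example (a boolean under "acc", a `(target, pred)` pair under "bleu"), `aggregation` maps each key to a function that reduces the collected list, and `higher_is_better` records the direction. Below is a minimal sketch of that reduction step only, assuming `metrics.bleu` accepts the list of `(target, pred)` pairs exactly as collected above; the example strings are made up and the loop is not the harness's real evaluator.

from lm_eval import metrics
from lm_eval.metrics import mean

# Per-example dicts as process_results would emit them (fabricated strings).
per_example = [
    {"acc": True, "bleu": ("the cat sat on the mat", "the cat sat on the mat")},
    {"acc": False, "bleu": ("a dog ran outside", "the dog ran outside")},
]

# What aggregation() returns for a prompt scored with Accuracy and BLEU.
aggregations = {"acc": mean, "bleu": metrics.bleu}

# Evaluator-side reduction: gather each key's values, then reduce them.
corpus_scores = {
    key: agg([example[key] for example in per_example])
    for key, agg in aggregations.items()
}
print(corpus_scores)  # e.g. {'acc': 0.5, 'bleu': <corpus BLEU score>}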
...
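For completeness, the `prompt` handed to `__init__` is a promptsource template, and `prompt.metadata.metrics` is the list of metric names the loops above gate against `CONFIGURED_PS_METRICS`. A hedged sketch of inspecting templates that way, assuming promptsource's `DatasetTemplates` loader and its `all_template_names` / `metadata.metrics` attributes; the dataset choice is illustrative.

from promptsource.templates import DatasetTemplates

CONFIGURED_PS_METRICS = {"Accuracy", "BLEU", "ROUGE"}

# Illustrative dataset/subset; any promptsource-covered dataset works the same way.
templates = DatasetTemplates("super_glue", "rte")

for name in templates.all_template_names:
    prompt = templates[name]
    for metric in prompt.metadata.metrics:
        status = ("has default behavior"
                  if metric in CONFIGURED_PS_METRICS
                  else "needs a task-specific solution")
        print(f"{name}: {metric} {status}")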