Unverified commit 54999199, authored by Stella Biderman, committed by GitHub

Merge pull request #2 from cjlovering/master

Pulling eval harness updates
parents 6caa0afd 18af502b
import abc
from typing import Iterable, Optional
import promptsource
import numpy as np
import random
import re
@@ -12,6 +14,7 @@ from tqdm import tqdm
import torch
import torch.nn.functional as F
from lm_eval import metrics
from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
from abc import abstractmethod
@@ -24,17 +27,17 @@ class LM(abc.ABC):
    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.
        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list
            A list of pairs (logprob, isgreedy)
@@ -97,7 +100,7 @@ class LM(abc.ABC):
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list
            A list of strings continuation
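        # Illustrative sketch, not part of the diff: how these two entry points are typically
        # called. `MyCausalLM` is a hypothetical BaseLM subclass used only for illustration.
        #
        #     lm = MyCausalLM()
        #     # loglikelihood: (context, continuation) pairs -> (logprob, isgreedy) pairs
        #     [(logprob, is_greedy)] = lm.loglikelihood([("The capital of France is", " Paris")])
        #     # greedy_until: generate until a stop sequence or the length budget is reached
        #     [generation] = lm.greedy_until(
        #         [("Q: 2+2=?\nA:", {"stopping_criteria": "\nQ:", "max_generation_length": 32})]
        #     )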
@@ -118,7 +121,6 @@ class LM(abc.ABC):
class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
@@ -145,13 +147,16 @@ class BaseLM(LM):
        pass
    @abstractmethod
    def tok_encode(self, string: str):
        pass
    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass
    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass
    @abstractmethod
    def _model_call(self, inps):
@@ -187,23 +192,30 @@ class BaseLM(LM):
        # TODO: automatic batch size detection for vectorization
        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )
            # discard is_greedy
            string_nll = [x[0] for x in string_nll]
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
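            # Illustrative note, not part of the diff: for documents longer than self.max_length,
            # get_rolling_token_windows yields overlapping (context, target) windows so that every
            # token is scored exactly once (the EOT token serves as context for the very first
            # token), and make_disjoint_window trims the windows so no target token is counted
            # twice; summing the per-window log-likelihoods gives the document-level value.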
@@ -223,10 +235,12 @@ class BaseLM(LM):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)
        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []
@@ -252,44 +266,60 @@ class BaseLM(LM):
                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape
                cont = continuation_enc
                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )
                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )
                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)
            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]
            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(
                    0
                )  # [1, seq, vocab]
                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
                    0
                )  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()
                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                    -1
                )  # [1, seq]
                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))
@@ -301,9 +331,9 @@ class BaseLM(LM):
                res.append(answer)
        return reord.get_original(res)
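        # Illustrative note, not part of the diff: because inp holds (context + continuation)[:-1],
        # position j predicts token j + 1 of the full sequence, so logits[inplen - contlen : inplen]
        # are exactly the predictions for the continuation tokens; gathering cont_toks along the
        # vocab axis yields one log-probability per continuation token, and their sum is the
        # continuation log-likelihood returned in `answer`.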
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are
        # multiple tokens or that span multiple tokens correctly
        # TODO: extract to TokenizedLM?
@@ -312,29 +342,46 @@ class BaseLM(LM):
        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]
        reord = utils.Reorderer(requests, _collate)
        for context, request_args in tqdm(reord.get_reordered()):
            stopping_criteria = request_args["stopping_criteria"]
            max_generation_length = request_args["max_generation_length"]
            assert isinstance(stopping_criteria, str) or stopping_criteria is None
            assert (
                isinstance(max_generation_length, int) or max_generation_length is None
            )
            until = [stopping_criteria]
            primary_until = self.tok_encode(until[0])
            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)
            if max_generation_length is None:
                max_length = context_enc.shape[1] + self.max_gen_toks
            else:
                max_length = min(
                    max_generation_length, context_enc.shape[1] + self.max_gen_toks
                )
            cont = self._model_generate(
                context_enc,
                max_length,
                torch.tensor(primary_until),
            )
            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
            for term in until:
                s = s.split(term)[0]
            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            res.append(s)
        return reord.get_original(res)
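        # Illustrative note, not part of the diff: with this change each greedy_until request
        # carries a request_args dict rather than a bare `until` string, e.g.
        #     ("Question: ...\nAnswer:", {"stopping_criteria": "\n", "max_generation_length": 64})
        # which is the shape produced by PromptSourceTask.construct_requests later in this diff.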
@@ -383,7 +430,7 @@ class Task(abc.ABC):
        self._fewshot_docs = None
    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.
        :param data_dir: str
@@ -412,7 +459,7 @@ class Task(abc.ABC):
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )
    @abstractmethod
@@ -478,22 +525,22 @@ class Task(abc.ABC):
    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.
        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass
    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document
        :param doc:
@@ -507,7 +554,7 @@ class Task(abc.ABC):
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass
@@ -516,22 +563,26 @@ class Task(abc.ABC):
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass
    def fewshot_description(self):
        import warnings
        warnings.warn(
            "`fewshot_description` will be removed in futures versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""
    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
        :param doc: str
@@ -548,7 +599,9 @@ class Task(abc.ABC):
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
@@ -556,7 +609,9 @@ class Task(abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )
        description = description + "\n\n" if description else ""
@@ -569,31 +624,229 @@ class Task(abc.ABC):
        else:
            if self._fewshot_docs is None:
                self._fewshot_docs = list(
                    self.validation_docs()
                    if self.has_validation_docs()
                    else self.test_docs()
                )
            fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
            # get rid of the doc that's the one we're evaluating, if it's in the fewshot
            fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )
        example = self.doc_to_text(doc)
        return description + labeled_examples + example
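        # Illustrative sketch, not part of the diff, of the assembled string for num_fewshot=1:
        #
        #     <description>\n\n
        #     <doc_to_text(example_1)><doc_to_target(example_1)>\n\n
        #     <doc_to_text(doc)>
        #
        # i.e. the optional description, then each labeled example separated by blank lines,
        # then the unanswered prompt for the document being evaluated.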
class PromptSourceTask(Task):
"""These are the metrics from promptsource that we have
added default behavior for. If you want to add default behavior for a new metric,
update the functions below. If you want to use one of the following metrics,
*and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
"""
CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt
def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
For example, for coqa, this is '\nQ:' and for drop '.'.
By default, it is None, meaning generation runs up to the maximum length or the EOT token, whichever comes first.
"""
return None
def max_generation_length(self) -> Optional[int]:
"""Denote where the max length of the generation if it is obvious from the task."""
return None
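# Illustrative sketch, not part of the diff: a task subclass would override these two hooks to
# bound generation; `MyGenerationTask` and its values are hypothetical.
#
#     class MyGenerationTask(PromptSourceTask):
#         def stopping_criteria(self):
#             return "\nQ:"
#
#         def max_generation_length(self):
#             return 64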
def invalid_doc_for_prompt(self, doc) -> bool:
"""Some prompts may not work for some documents."""
if (
# generate_paraphrase for mrpc
# This generation prompt assumes a positive example. We filter out the negative examples.
# https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
# https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
(
self.prompt.id == "3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
or self.prompt.id == "d830d7a5-abc0-4275-ac62-974e0088876f"
)
and doc["label"] == 0
):
return True
return False
def doc_to_target(self, doc) -> str:
"""NOTE: In the future, this may return Union[str, List[str]]."""
_, target = self.prompt.apply(doc)
return f" {target}"
def doc_to_text(self, doc) -> str:
text, _ = self.prompt.apply(doc)
return text
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
_requests = []
answer_choices_list = self.prompt.get_answer_choices_list(doc)
if answer_choices_list:
# If answer_choices_list, then this is a ranked choice prompt.
for answer_choice in answer_choices_list:
ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
_requests.append(ll_answer_choice)
else:
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
request_args = {
"stopping_criteria": self.stopping_criteria(),
"max_generation_length": self.max_generation_length(),
}
cont_request = rf.greedy_until(ctx, request_args)
_requests.append(cont_request)
return _requests
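# Illustrative note, not part of the diff: for a ranked-choice prompt with answer choices
# ["yes", "no"] this returns two loglikelihood requests (one per choice), whose scores are
# compared by argmax in process_results; for a generation prompt it returns a single
# greedy_until request whose request_args carry the stopping criteria and length bound
# defined by the two hooks above.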
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
target = self.doc_to_target(doc).strip()
answer_choices_list = self.prompt.get_answer_choices_list(doc)
if answer_choices_list:
# If answer_choices_list, then this is a ranked choice prompt.
# NOTE: In the future, target will be a list of strings.
# For now, we can assume there will be only 1 target, but it's possible
# that this is not the case, so we should check for that.
pred = answer_choices_list[np.argmax(results)]
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = pred == target
# TODO: Add metrics here.
return out
else:
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
pred = results[0].strip()
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "BLEU":
out["bleu"] = (target, pred)
if metric == "ROUGE":
# TODO: This computes all rouge sub-metrics. Find a generic
# way to handle user specified rouge sub-metrics to avoid extra
# compute.
rouge_scores = metrics.rouge(target, pred)
# Flatten rouge score dict.
rouge_scores = utils.flatten(rouge_scores)
# Merge all the rouge-type scores into the `out` dict.
out = {**out, **rouge_scores}
print(out)
return out
def higher_is_better(self):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = True
if metric == "BLEU":
out["bleu"] = True
if metric == "ROUGE":
# TODO: Find a generic way to handle user specified rouge metrics.
out["rouge1_precision"] = True
out["rouge1_recall"] = True
out["rouge1_fmeasure"] = True
out["rouge2_precision"] = True
out["rouge2_recall"] = True
out["rouge2_fmeasure"] = True
out["rougeL_precision"] = True
out["rougeL_recall"] = True
out["rougeL_fmeasure"] = True
out["rougeLsum_precision"] = True
out["rougeLsum_recall"] = True
out["rougeLsum_fmeasure"] = True
return out
def aggregation(self):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = mean
if metric == "BLEU":
out["bleu"] = metrics.bleu
if metric == "ROUGE":
# TODO: Find a generic way to handle user specified rouge metrics.
out["rouge1_precision"] = mean
out["rouge1_recall"] = mean
out["rouge1_fmeasure"] = mean
out["rouge2_precision"] = mean
out["rouge2_recall"] = mean
out["rouge2_fmeasure"] = mean
out["rougeL_precision"] = mean
out["rougeL_recall"] = mean
out["rougeL_fmeasure"] = mean
out["rougeLsum_precision"] = mean
out["rougeLsum_recall"] = mean
out["rougeLsum_fmeasure"] = mean
return out
class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]
    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]
        return lls
@@ -601,21 +854,21 @@ class MultipleChoiceTask(Task):
    def process_results(self, doc, results):
        gold = doc["gold"]
        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }
    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }
    def aggregation(self):
        return {
            "acc": mean,
@@ -624,7 +877,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
    def has_training_docs(self):
        return False
@@ -632,9 +884,15 @@ class PerplexityTask(Task, abc.ABC):
        assert k == 0
        return []
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
@@ -642,7 +900,9 @@ class PerplexityTask(Task, abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )
        return ""
@@ -665,7 +925,7 @@ class PerplexityTask(Task, abc.ABC):
        return req
    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
@@ -687,23 +947,23 @@ class PerplexityTask(Task, abc.ABC):
    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))
def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
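# Illustrative note, not part of the diff: hash_args keys the cache on the request type plus its
# arguments, e.g.
#     hash_args("loglikelihood", ("The capital of France is", " Paris"))
# returns a hex SHA-256 digest that CachingLM uses as the lookup key for that request.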
class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return
        self.dbdict = cachinglm.dbdict
    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
@@ -733,7 +993,7 @@ class CachingLM:
        def fn(requests):
            res = []
            remaining_reqs = []
            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
@@ -746,7 +1006,7 @@ class CachingLM:
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)
@@ -764,41 +1024,48 @@ class CachingLM:
            self.dbdict.commit()
            return res
        return fn
    def get_cache_hook(self):
        return CacheHook(self)
REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}
class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )
        self.request_type = request_type
        self.args = args
        self.index = index
    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)
    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)
    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )
    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"
@@ -808,6 +1075,7 @@ class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)
        return fn
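# Illustrative note, not part of the diff: rf.loglikelihood(ctx, " choice") builds a Request whose
# REQUEST_RETURN_LENGTHS entry is 2, so indexing it with [0] (as MultipleChoiceTask does above)
# keeps only the log-probability and discards the is-greedy flag once results are unpacked.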
...
@@ -2,25 +2,38 @@ import collections
import itertools
import pathlib
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np
from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
):
    """Instantiate and evaluate a model on a list of tasks.
    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
@@ -37,7 +50,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return
@@ -49,20 +62,28 @@ def simple_evaluate(model, model_args=None, tasks=[],
    assert tasks != [], "No tasks specified"
    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model
    # TODO: Hard-code turning off cache while testing. Remove once testing is completed.
    no_cache = True
    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )
    task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)
    if check_integrity:
        run_task_tests(task_list=tasks)
@@ -72,7 +93,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        description_dict=description_dict,
    )
    # add info about the model and few shot config
@@ -85,14 +106,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }
    return results
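# Illustrative sketch, not part of the diff: a typical invocation of the promptsource-enabled
# pipeline; the model arguments and task name below are placeholders.
#
#     results = simple_evaluate(
#         model="hf",
#         model_args="pretrained=gpt2",
#         tasks=["mrpc"],
#         num_fewshot=0,
#         device="cuda",
#     )
#     print(make_table(results))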
@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
):
    """Instantiate and evaluate a model on a list of tasks.
    :param lm: obj
@@ -108,7 +137,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
@@ -118,12 +147,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )
    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]
    results = collections.defaultdict(dict)
@@ -141,8 +172,12 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    docs = {}
    # get lists of each type of request
    for task_prompt_name, task in task_dict_items:
        # if task.is_generation_task():
        #     print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
        #     continue
        versions[task_prompt_name] = task.VERSION
        # default to test doc, fall back to val doc if validation unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
@@ -158,15 +193,19 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        rnd.seed(42)
        rnd.shuffle(task_docs)
        description = (
            description_dict[task_prompt_name]
            if description_dict and task_prompt_name in description_dict
            else ""
        )
        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if task.invalid_doc_for_prompt(doc):
                continue
            docs[(task_prompt_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
@@ -175,7 +214,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append(
                    (i, task_prompt_name, doc, doc_id)
                )
    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)
@@ -189,43 +230,49 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]
        for resp, (i, task_prompt_name, doc, doc_id) in zip(
            resps, requests_origin[reqtype]
        ):
            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
    vals = collections.defaultdict(list)
    # unpack results and sort back in order and return control to Task
    for (task_prompt_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]
        task = task_dict[task_prompt_name]
        doc = docs[(task_prompt_name, doc_id)]
        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_prompt_name, metric)].append(value)
    # aggregate results
    for (task_prompt_name, metric), items in vals.items():
        task_name, prompt_name = task_prompt_name.split("+")
        results[task_prompt_name]["task_name"] = task_name
        results[task_prompt_name]["prompt_name"] = prompt_name
        task = task_dict[task_prompt_name]
        results[task_prompt_name][metric] = task.aggregation()[metric](items)
        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )
        if stderr is not None:
            results[task_prompt_name][metric + "_stderr"] = stderr(items)
    return {"results": dict(results), "versions": dict(versions)}
def make_table(result_dict):
@@ -234,22 +281,50 @@ def make_table(result_dict):
    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Prompt", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = [
        "Task",
        "Prompt",
        "Version",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]
    values = []
    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue
            if "_name" in m:
                continue
            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append(
                    [
                        dic["task_name"],
                        dic["prompt_name"],
                        version,
                        m,
                        "%.4f" % v,
                        "±",
                        "%.4f" % se,
                    ]
                )
            else:
                values.append(
                    [
                        dic["task_name"],
                        dic["prompt_name"],
                        version,
                        m,
                        "%.4f" % v,
                        "",
                        "",
                    ]
                )
        k = ""
        version = ""
    md_writer.value_matrix = values
...
import typing
import math
from collections.abc import Iterable
import numpy as np
import sacrebleu
from rouge_score import rouge_scorer
import sklearn.metrics
import random
@@ -184,6 +186,74 @@ def _sacreformat(refs, preds):
    return refs, preds
def rouge(
refs: typing.List[str],
pred: str,
rouge_types: typing.List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
):
""" ROUGE with multi-reference support
Implementation based on GEM-metrics:
https://github.com/GEM-benchmark/GEM-metrics/blob/431a8174bd6b3637e8d6118bfad2983e39e99733/gem_metrics/rouge.py
:param refs:
A `list` of reference `str`s.
:param pred:
A single prediction `str`.
"""
# Add newlines between sentences to correctly compute `rougeLsum`.
if "rougeLsum" in rouge_types:
# TODO: Adapt this to handle languages that do not support sentence endings by `.`.
# See GEM-metrics implementation with lang specific `nltk` tokenizers to
# split sentences.
pred = pred.replace(".", ".\n")
refs = [ref.replace(".", ".\n") for ref in refs]
scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
# ROUGE multi-ref jackknifing
if len(refs) > 1:
cur_scores = [scorer.score(ref, pred) for ref in refs]
# get best score for all leave-one-out sets
best_scores = []
for leave in range(len(refs)):
cur_scores_leave_one = [
cur_scores[s] for s in range(len(refs)) if s != leave
]
best_scores.append(
{
rouge_type: max(
[s[rouge_type] for s in cur_scores_leave_one],
key=lambda s: s.fmeasure,
)
for rouge_type in rouge_types
}
)
# average the leave-one-out bests to produce the final score
score = {
rouge_type: rouge_scorer.scoring.Score(
np.mean([b[rouge_type].precision for b in best_scores]),
np.mean([b[rouge_type].recall for b in best_scores]),
np.mean([b[rouge_type].fmeasure for b in best_scores]),
)
for rouge_type in rouge_types
}
else:
score = scorer.score(refs[0], pred)
# convert the named tuples to plain nested dicts
score = {
rouge_type: {
"precision": score[rouge_type].precision,
"recall": score[rouge_type].recall,
"fmeasure": score[rouge_type].fmeasure,
}
for rouge_type in rouge_types
}
return score
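# Illustrative usage sketch, not part of the diff:
#
#     scores = rouge(["the cat sat on the mat"], "a cat sat on the mat")
#     scores["rouge1"]["fmeasure"]  # float in [0, 1]
#
# With several references the leave-one-out jackknifing above is applied; with a single reference
# the scorer is applied directly, and the named tuples are flattened to plain nested dicts.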
# stderr stuff
class _bootstrap_internal:
...
from . import gpt2
from . import gptj
from . import gpt3
from . import t5
from . import t0
from . import dummy
MODEL_REGISTRY = {
    "hf": gpt2.HFLM,
    "gpt2": gpt2.GPT2LM,
    "gptj": gptj.GPTJLM,
    "gpt3": gpt3.GPT3LM,
    "t5": t5.T5LM,
    "mt5": t5.T5LM,
    "t0": t0.T0LM,
    "dummy": dummy.DummyLM,
}
...
@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
class HFLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        pretrained="gpt2",
        revision="main",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
    ):
        super().__init__()
        assert isinstance(device, str)
@@ -15,28 +22,47 @@ class HFLM(BaseLM):
        if device:
            self._device = torch.device(device)
        else:
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
        ).to(self.device)
        self.gpt2.eval()
        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
            subfolder=subfolder,
        )
        assert isinstance(
            self.tokenizer,
            (
                transformers.GPT2Tokenizer,
                transformers.GPT2TokenizerFast,
                transformers.T5Tokenizer,
                transformers.T5TokenizerFast,
            ),
        ), "this tokenizer has not been checked for compatibility yet!"
        self.vocab_size = self.tokenizer.vocab_size
        if isinstance(
            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
        ):
            assert self.tokenizer.encode("hello\n\nhello") == [
                31373,
                198,
                198,
                31373,
            ], self.tokenizer.encode("hello\n\nhello")
        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
@@ -75,7 +101,7 @@ class HFLM(BaseLM):
    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)
@@ -89,14 +115,42 @@ class HFLM(BaseLM):
        """
        with torch.no_grad():
            return self.gpt2(inps)[0][:, :, :50257]
    def _get_stopping_criteria(self, stopping_criteria_ids):
        class MultitokenEOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
                self.eos_seq = tokenizer.decode(eos_seq_id)
                self.eos_seq_id = eos_seq_id
                self.eos_seq_len = len(eos_seq_id) + 1
                self.tokenizer = tokenizer
            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                last_token_id = input_ids[0, -self.eos_seq_len:]
                last_tokens = self.tokenizer.decode(last_token_id)
                is_stopped = self.eos_seq in last_tokens
                return is_stopped
        class EOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_token_id: torch.LongTensor):
                self.eos_token_id = eos_token_id
            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                return input_ids[0, -1] == self.eos_token_id
        return transformers.StoppingCriteriaList([
            MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
            EOSCriteria(self.tokenizer.eos_token)
        ])
    def _model_generate(self, context, max_length, stopping_criteria_ids):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
        return self.gpt2.generate(
            context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            do_sample=False,
        )
    # for backwards compatibility
...
import transformers
import torch
from lm_eval.base import BaseLM
class GPTJLM(BaseLM):
def __init__(
self,
device="cuda",
batch_size=1,
):
super().__init__()
assert isinstance(device, str)
assert isinstance(batch_size, int)
if device:
self._device = torch.device(device)
else:
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
pretrained = "EleutherAI/gpt-j-6B"
self.gptj = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
self.gptj.eval()
# pretrained tokenizer for neo is broken for now so just hard-coding this to gptj
self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
self.vocab_size = self.tokenizer.vocab_size
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
# TODO: fix multi-gpu
# gpus = torch.cuda.device_count()
# if gpus > 1:
# self.gptj = nn.DataParallel(self.gptj)
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
try:
return self.gptj.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
return self.gptj.config.max_position_embeddings
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
# TODO: fix multi-gpu
return self.batch_size_per_gpu # * gpus
@property
def device(self):
# TODO: fix multi-gpu
return self._device
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
# keep only the 50257 GPT-2 BPE vocabulary entries; GPT-J pads its output layer beyond that
return self.gptj(inps)[0][:, :, :50257]
def _get_stopping_criteria(self, stopping_criteria_ids):
class MultitokenEOSCriteria(transformers.StoppingCriteria):
def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
self.eos_seq = tokenizer.decode(eos_seq_id)
self.eos_seq_id = eos_seq_id
self.eos_seq_len = len(eos_seq_id) + 1
self.tokenizer = tokenizer
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_token_id = input_ids[0, -self.eos_seq_len:]
last_tokens = self.tokenizer.decode(last_token_id)
is_stopped = self.eos_seq in last_tokens
return is_stopped
class EOSCriteria(transformers.StoppingCriteria):
def __init__(self, eos_token_id: int):
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids[0, -1] == self.eos_token_id
return transformers.StoppingCriteriaList([
MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
# compare against the eos token *id*, not the token string
EOSCriteria(self.tokenizer.eos_token_id)
])
def _model_generate(self, context, max_length, stopping_criteria_ids):
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return self.gptj.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
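# Illustrative sketch, not part of this change: GPTJLM reuses the BaseLM request
# API, so loglikelihood scoring looks the same as for the GPT-2 wrapper. The
# example strings are made up.
#
#   lm = GPTJLM(device="cuda", batch_size=1)
#   (logprob, is_greedy), = lm.loglikelihood([("The capital of France is", " Paris")])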
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
import math
class T0LM(LM):
MAX_GEN_TOKS = 256
MAX_INP_LENGTH = 512
VOCAB_SIZE = 32100
EOT_TOKEN_ID = 1
def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1):
super().__init__()
if device:
self.device = torch.device(device)
else:
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(pretrained)
self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
self.t0.eval()
if parallelize == "True":
print(parallelize)
self.t0.parallelize()
self.device = torch.device('cuda:0')
else:
self.t0.to(self.device)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
self.max_length = self.MAX_INP_LENGTH
self.batch_size = int(batch_size)
@classmethod
def create_from_arg_string(cls, arg_string, additional_config={}):
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
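# Illustrative sketch, not part of this change: model arguments come in as a single
# comma-separated string from the CLI, which is why booleans like `parallelize` are
# compared as strings in __init__. The checkpoint name and batch size are only examples.
#
#   t0 = T0LM.create_from_arg_string("pretrained=bigscience/T0_3B,parallelize=False", {"batch_size": 4})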
def loglikelihood(self, requests):
res = []
for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
inputs, targets = zip(*chunk)
inputs_tok = self.tokenizer(
list(inputs),
max_length=self.max_length,
padding=True,
# truncation=True,
add_special_tokens=False,
return_tensors="pt"
).to(self.device)
for key in inputs_tok:
inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :]
targets_tok = self.tokenizer(
list(targets),
max_length=self.MAX_GEN_TOKS,
padding=True,
# truncation=True,
add_special_tokens=False,
return_tensors="pt"
).to(self.device)
for key in targets_tok:
targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]
with torch.no_grad():
outputs = self.t0(
**inputs_tok,
labels=targets_tok["input_ids"]
)
log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
output_iterator = zip(
chunk,
log_softmaxes,
targets_tok["input_ids"],
targets_tok["attention_mask"],
)
for cache_key, log_softmax, target_tok, target_mask in output_iterator:
length = target_mask.sum()
log_softmax = log_softmax[:length]
target_tok = target_tok[:length]
greedy_tokens = log_softmax.argmax(dim=-1)
max_equal = (greedy_tokens == target_tok).all()
target_logits = torch.gather(
log_softmax, 1, target_tok.unsqueeze(-1)
).squeeze(-1)
answer = (float(target_logits.sum()), bool(max_equal))
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
res.append(answer)
return res
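# Illustrative sketch, not part of this change: each loglikelihood request is an
# (input, target) string pair, and each result is (summed target logprob, is_greedy).
# The request below is made up.
#
#   [(ll, greedy)] = t0.loglikelihood([("Review: great film. Is this positive or negative?", "positive")])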
def loglikelihood_rolling(self, requests):
raise NotImplementedError
def _get_stopping_criteria(self, stopping_criteria_ids):
class MultitokenEOSCriteria(transformers.StoppingCriteria):
def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
self.eos_seq = tokenizer.decode(eos_seq_id)
self.eos_seq_id = eos_seq_id
self.eos_seq_len = len(eos_seq_id) + 1
self.tokenizer = tokenizer
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_token_id = input_ids[0, -self.eos_seq_len:]
last_tokens = self.tokenizer.decode(last_token_id)
is_stopped = self.eos_seq in last_tokens
return is_stopped
class EOSCriteria(transformers.StoppingCriteria):
def __init__(self, eos_token_id: int):
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids[0, -1] == self.eos_token_id
return transformers.StoppingCriteriaList([
MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
# compare against the eos token *id*, not the token string
EOSCriteria(self.tokenizer.eos_token_id)
])
def greedy_until(self, requests):
res = []
for context, until in tqdm(requests):
if isinstance(until, str): until = [until]
context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
stopping_criteria_ids = self.tokenizer.encode(until[0])
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
cont = self.t0.generate(
context_enc,
max_length=self.MAX_GEN_TOKS,
stopping_criteria=stopping_criteria,
do_sample=False
)
s = self.tokenizer.decode(cont[0].tolist())
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s)
return res
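# Illustrative sketch, not part of this change: a greedy_until request pairs a context
# with its stop string(s); only the first stop string is turned into stopping criteria
# above. The request below is made up.
#
#   [generation] = t0.greedy_until([("Translate English to French: cheese =>", ["\n"])])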
\ No newline at end of file
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
import math
class T5LM(LM):
MAX_GEN_TOKS = 256
MAX_INP_LENGTH = 512
VOCAB_SIZE = 32128
EOT_TOKEN_ID = 1
def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1):
super().__init__()
if device:
self.device = torch.device(device)
else:
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(pretrained)
self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
self.t5.eval()
if parallelize == "True":
print(parallelize)
self.t5.parallelize()
self.device = torch.device('cuda:0')
else:
self.t5.to(self.device)
self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained)
self.max_length = self.MAX_INP_LENGTH
self.batch_size = int(batch_size)
@classmethod
def create_from_arg_string(cls, arg_string, additional_config={}):
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
def loglikelihood(self, requests):
res = []
for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
inputs, targets = zip(*chunk)
inputs_tok = self.tokenizer(
list(inputs),
max_length=self.max_length,
padding=True,
# truncation=True,
add_special_tokens=False,
return_tensors="pt"
).to(self.device)
for key in inputs_tok:
inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :]
targets_tok = self.tokenizer(
list(targets),
max_length=self.MAX_GEN_TOKS,
padding=True,
# truncation=True,
add_special_tokens=False,
return_tensors="pt"
).to(self.device)
for key in targets_tok:
targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]
with torch.no_grad():
outputs = self.t5(
**inputs_tok,
labels=targets_tok["input_ids"]
)
log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
output_iterator = zip(
chunk,
log_softmaxes,
targets_tok["input_ids"],
targets_tok["attention_mask"],
)
for cache_key, log_softmax, target_tok, target_mask in output_iterator:
length = target_mask.sum()
log_softmax = log_softmax[:length]
target_tok = target_tok[:length]
greedy_tokens = log_softmax.argmax(dim=-1)
max_equal = (greedy_tokens == target_tok).all()
target_logits = torch.gather(
log_softmax, 1, target_tok.unsqueeze(-1)
).squeeze(-1)
answer = (float(target_logits.sum()), bool(max_equal))
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
res.append(answer)
return res
def loglikelihood_rolling(self, requests):
raise NotImplementedError
def _get_stopping_criteria(self, stopping_criteria_ids):
class MultitokenEOSCriteria(transformers.StoppingCriteria):
def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
self.eos_seq = tokenizer.decode(eos_seq_id)
self.eos_seq_id = eos_seq_id
self.eos_seq_len = len(eos_seq_id) + 1
self.tokenizer = tokenizer
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_token_id = input_ids[0, -self.eos_seq_len:]
last_tokens = self.tokenizer.decode(last_token_id)
is_stopped = self.eos_seq in last_tokens
return is_stopped
class EOSCriteria(transformers.StoppingCriteria):
def __init__(self, eos_token_id: int):
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids[0, -1] == self.eos_token_id
return transformers.StoppingCriteriaList([
MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
# compare against the eos token *id*, not the token string
EOSCriteria(self.tokenizer.eos_token_id)
])
def greedy_until(self, requests):
res = []
for context, until in tqdm(requests):
if isinstance(until, str): until = [until]
context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
stopping_criteria_ids = self.tokenizer.encode(until[0])
stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
cont = self.t5.generate(
context_enc,
max_length=self.MAX_GEN_TOKS,
stopping_criteria=stopping_criteria,
do_sample=False
)
s = self.tokenizer.decode(cont[0].tolist())
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s)
return res
\ No newline at end of file
from promptsource.templates import DatasetTemplates
from pprint import pprint from pprint import pprint
from typing import List, Union from typing import List, Union
...@@ -51,6 +52,9 @@ from . import blimp ...@@ -51,6 +52,9 @@ from . import blimp
from . import asdiv from . import asdiv
from . import gsm8k from . import gsm8k
from . import storycloze from . import storycloze
from . import hans
# from . import e2e_nlg_cleaned
######################################## ########################################
# Translation tasks # Translation tasks
...@@ -58,8 +62,8 @@ from . import storycloze ...@@ -58,8 +62,8 @@ from . import storycloze
# 6 total # 6 total
gpt3_translation_benchmarks = { gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French "wmt14": ["en-fr", "fr-en"], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
} }
...@@ -67,7 +71,7 @@ gpt3_translation_benchmarks = { ...@@ -67,7 +71,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = { selected_translation_benchmarks = {
**gpt3_translation_benchmarks, **gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic "iwslt17": ["en-ar", "ar-en"], # Arabic
} }
# 319 total # 319 total
...@@ -91,7 +95,7 @@ TASK_REGISTRY = { ...@@ -91,7 +95,7 @@ TASK_REGISTRY = {
"rte": glue.RTE, "rte": glue.RTE,
"qnli": glue.QNLI, "qnli": glue.QNLI,
"qqp": glue.QQP, "qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet # "stsb": glue.STSB, # not implemented yet
"sst": glue.SST, "sst": glue.SST,
"wnli": glue.WNLI, "wnli": glue.WNLI,
# SuperGLUE # SuperGLUE
...@@ -102,34 +106,27 @@ TASK_REGISTRY = { ...@@ -102,34 +106,27 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD, "record": superglue.ReCoRD,
"wic": superglue.WordsInContext, "wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge, "wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre? # Order by benchmark/genre?
"coqa": coqa.CoQA, "coqa": coqa.CoQA,
"drop": drop.DROP, "drop": drop.DROP,
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze, "lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada # multilingual lambada
**lambada_multilingual.construct_tasks(), **lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText, "wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"prost": prost.PROST, "prost": prost.PROST,
"mc_taco": mc_taco.MCTACO, "mc_taco": mc_taco.MCTACO,
# Science related # Science related
"pubmedqa" : pubmedqa.Pubmed_QA, "pubmedqa": pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ, "sciq": sciq.SciQ,
# "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned,
"qasper": qasper.QASPER, "qasper": qasper.QASPER,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2011" : qa4mre.QA4MRE_2011, "qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2012" : qa4mre.QA4MRE_2012, "qa4mre_2013": qa4mre.QA4MRE_2013,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA, "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
...@@ -140,7 +137,7 @@ TASK_REGISTRY = { ...@@ -140,7 +137,7 @@ TASK_REGISTRY = {
"squad2": squad.SQuAD2, "squad2": squad.SQuAD2,
"race": race.RACE, "race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet # "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs, "headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn, "headqa_en": headqa.HeadQAEn,
"mathqa": mathqa.MathQA, "mathqa": mathqa.MathQA,
...@@ -150,21 +147,18 @@ TASK_REGISTRY = { ...@@ -150,21 +147,18 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1, "anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2, "anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3, "anli_r3": anli.ANLIRound3,
"hans": hans.HANS,
"ethics_cm": hendrycks_ethics.EthicsCM, "ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue # dialogue
"mutual": mutual.MuTual, "mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus, "mutual_plus": mutual.MuTualPlus,
# math # math
"math_algebra": hendrycks_math.MathAlgebra, "math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
...@@ -175,7 +169,6 @@ TASK_REGISTRY = { ...@@ -175,7 +169,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus, "math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv, "math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K, "gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic # arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
...@@ -189,22 +182,18 @@ TASK_REGISTRY = { ...@@ -189,22 +182,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite, "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks # TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations # e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks) # hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(), **hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en # e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20 # chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks), **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks # Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1, "anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2, "anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters, "cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion, "random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords, "reversed_words": unscramble.ReversedWords,
# Pile # Pile
"pile_arxiv": pile.PileArxiv, "pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3, "pile_books3": pile.PileBooks3,
...@@ -228,7 +217,6 @@ TASK_REGISTRY = { ...@@ -228,7 +217,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia, "pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles, "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP # BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
...@@ -297,7 +285,6 @@ TASK_REGISTRY = { ...@@ -297,7 +285,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data. # Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018, # "storycloze_2018": storycloze.StoryCloze2018,
...@@ -321,19 +308,51 @@ def get_task_name_from_object(task_object): ...@@ -321,19 +308,51 @@ def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items(): for name, class_ in TASK_REGISTRY.items():
if class_ is task_object: if class_ is task_object:
return name return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = { task_name_dict = {
task_name: get_task(task_name)() task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str) for task_name in task_name_list
if isinstance(task_name, str)
} }
task_name_from_object_dict = { task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str) for task_object in task_name_list
if not isinstance(task_object, str)
} }
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict} return {**task_name_dict, **task_name_from_object_dict}
def get_task_dict_promptsource(task_name_list: List[str]):
"""Loads a task instance for each prompt written for that task."""
task_name_dict = {}
for task_name in task_name_list:
assert isinstance(task_name, str)
# Static version of the Task. Use this to get the HF dataset path / name.
static_task_obj = get_task(task_name)
# Create the proper task name arg for DatasetTemplates.
sub_task = (
f"/{static_task_obj.DATASET_NAME}" if static_task_obj.DATASET_NAME else ""
)
ps_task_name = f"{static_task_obj.DATASET_PATH}{sub_task}"
task_prompts = DatasetTemplates(ps_task_name)
for prompt_name in task_prompts.all_template_names:
prompt = task_prompts[prompt_name]
# NOTE: We join the task and prompt names with a separator that is easy to split on later.
task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)(
prompt=prompt
)
return task_name_dict
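# Illustrative sketch, not part of this change: every registered task is expanded into
# one instance per promptsource template, keyed as "<task>+<prompt name>", e.g.
#
#   task_dict = get_task_dict_promptsource(["rte"])
#   list(task_dict.keys())  # -> ["rte+<prompt name 1>", "rte+<prompt name 2>", ...]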
...@@ -10,7 +10,7 @@ provided explanations. ...@@ -10,7 +10,7 @@ provided explanations.
Homepage: "https://github.com/facebookresearch/anli" Homepage: "https://github.com/facebookresearch/anli"
""" """
import numpy as np import numpy as np
from lm_eval.base import rf, Task from lm_eval.base import rf, PromptSourceTask
from lm_eval.metrics import mean from lm_eval.metrics import mean
...@@ -30,7 +30,7 @@ _CITATION = """ ...@@ -30,7 +30,7 @@ _CITATION = """
""" """
class ANLIBase(Task): class ANLIBase(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "anli" DATASET_PATH = "anli"
DATASET_NAME = None DATASET_NAME = None
...@@ -59,51 +59,6 @@ class ANLIBase(Task): ...@@ -59,51 +59,6 @@ class ANLIBase(Task):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["test_r" + str(self.SPLIT)] return self.dataset["test_r" + str(self.SPLIT)]
def doc_to_text(self, doc):
# OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
# of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
# appended onto the question, with no "Answer:" or even a newline. Do we *really*
# want to do it exactly as OA did?
return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
def doc_to_target(self, doc):
# True = entailment
# False = contradiction
# Neither = neutral
return " " + ["True", "Neither", "False"][doc['label']]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
ll_true, _ = rf.loglikelihood(ctx, " True")
ll_neither, _ = rf.loglikelihood(ctx, " Neither")
ll_false, _ = rf.loglikelihood(ctx, " False")
return ll_true, ll_neither, ll_false
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
gold = doc["label"]
pred = np.argmax(results)
return {
"acc": pred == gold
}
def aggregation(self): def aggregation(self):
""" """
:returns: {str: [float] -> float} :returns: {str: [float] -> float}
......
...@@ -58,10 +58,11 @@ class Arithmetic(Task): ...@@ -58,10 +58,11 @@ class Arithmetic(Task):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll, is_prediction = rf.loglikelihood(ctx, doc["completion"]) ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
return is_prediction return ll, is_prediction
def process_results(self, doc, results): def process_results(self, doc, results):
is_prediction, = results ll, is_prediction = results
return { return {
"acc": is_prediction "acc": is_prediction
} }
......
...@@ -12,7 +12,7 @@ Homepage: https://stanfordnlp.github.io/coqa/ ...@@ -12,7 +12,7 @@ Homepage: https://stanfordnlp.github.io/coqa/
import inspect import inspect
import transformers.data.metrics.squad_metrics as squad_metrics import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean from lm_eval.base import PromptSourceTask, Task, rf, mean
from itertools import zip_longest from itertools import zip_longest
...@@ -28,9 +28,9 @@ _CITATION = """ ...@@ -28,9 +28,9 @@ _CITATION = """
""" """
class CoQA(Task): class CoQA(PromptSourceTask):
VERSION = 1 VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa) DATASET_PATH = "coqa"
DATASET_NAME = None DATASET_NAME = None
def has_training_docs(self): def has_training_docs(self):
...@@ -51,44 +51,21 @@ class CoQA(Task): ...@@ -51,44 +51,21 @@ class CoQA(Task):
def test_docs(self): def test_docs(self):
pass pass
def doc_to_text(self, doc): # @classmethod
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1} # def get_answers(cls, doc, turn_id):
# and a question qi, the task is to predict the answer ai # # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
doc_text = doc["story"] + '\n\n' # answers = []
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai # answer_forturn = doc["answers"]["input_text"][turn_id - 1]
question = f"Q: {q}\n\n" # answers.append(answer_forturn)
answer = f"A: {a}\n\n" if a is not None else "A:" # additional_answers = doc.get("additional_answers")
doc_text += question + answer # if additional_answers:
return doc_text # for key in additional_answers:
# additional_answer_for_turn = additional_answers[key]["input_text"][
@classmethod # turn_id - 1
def get_answers(cls, doc, turn_id): # ]
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers). # if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers = [] # answers.append(additional_answer_for_turn)
answer_forturn = doc["answers"]["input_text"][turn_id - 1] # return answers
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
@classmethod
def get_answer_choice(self, raw_text):
# Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return '0'
if squad_metrics.normalize_answer(raw_text) == "yes":
return '1'
if squad_metrics.normalize_answer(raw_text) == "no":
return '2'
return '3' # Not a yes/no question
@staticmethod @staticmethod
def compute_scores(gold_list, pred): def compute_scores(gold_list, pred):
...@@ -98,40 +75,40 @@ class CoQA(Task): ...@@ -98,40 +75,40 @@ class CoQA(Task):
em_sum = 0.0 em_sum = 0.0
if len(gold_list) > 1: if len(gold_list) > 1:
for i in range(len(gold_list)): for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:] gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum # predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers) em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else: else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))} return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
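# Illustrative sketch, not part of this change: compute_scores applies SQuAD-style
# normalization (case, punctuation, and articles are stripped) before EM/F1. The
# answers below are made up.
#
#   CoQA.compute_scores(["the Eiffel Tower"], "Eiffel Tower")
#   # -> {"em": 1.0, "f1": 1.0}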
def doc_to_target(self, doc, turnid=None): def stopping_criteria(self):
# Default to prediction of last turn. return "\n\n"
if turnid is None:
turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1]
return " " + raw_text
def construct_requests(self, doc, ctx): # def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of # """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. # Requests which will be sent to the LM.
:param doc: # :param doc:
The document as returned from training_docs, validation_docs, or test_docs. # The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str # :param ctx: str
The context string, generated by fewshot_context. This includes the natural # The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question # language description, as well as the few shot examples, and the question
part of the document for `doc`. # part of the document for `doc`.
""" # """
cont_request = rf.greedy_until(ctx, ['\nQ:']) # return cont_request
return cont_request
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -139,15 +116,25 @@ class CoQA(Task): ...@@ -139,15 +116,25 @@ class CoQA(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
turn_id = len(doc["questions"]["input_text"]) target = self.doc_to_target(doc).strip()
gold_list = self.get_answers(doc, turn_id) pred = results[0].strip().split("\n")[0]
pred = results[0].strip().split('\n')[0] print("*" * 80)
print(f"DOC: {doc}")
scores = self.compute_scores(gold_list, pred) # print(f"PS: {self.prompt.apply(doc)}")
print(f"TEXT: {self.doc_to_text(doc)}")
print(f"TARGET: {target} END TARGET")
print(f"PRED: {pred} END PRED")
print("*" * 80)
# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)
# TODO: Add HF metrics mapped from promptsource metadata.
scores = self.compute_scores([target], pred)
return { return {
"f1": scores['f1'], "f1": scores["f1"],
"em": scores['em'], "em": scores["em"],
} }
def higher_is_better(self): def higher_is_better(self):
......
...@@ -18,7 +18,7 @@ import re ...@@ -18,7 +18,7 @@ import re
import string import string
import lm_eval.datasets.drop.drop import lm_eval.datasets.drop.drop
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf from lm_eval.base import PromptSourceTask, rf
from lm_eval.metrics import mean from lm_eval.metrics import mean
...@@ -37,9 +37,9 @@ _CITATION = """ ...@@ -37,9 +37,9 @@ _CITATION = """
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE) _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task): class DROP(PromptSourceTask):
VERSION = 1 VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop) DATASET_PATH = "drop" # inspect.getfile(lm_eval.datasets.drop.drop)
DATASET_NAME = None DATASET_NAME = None
def has_training_docs(self): def has_training_docs(self):
...@@ -52,46 +52,13 @@ class DROP(Task): ...@@ -52,46 +52,13 @@ class DROP(Task):
return False return False
def training_docs(self): def training_docs(self):
if self._training_docs is None: # if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"])) # self._training_docs = list()
return self._training_docs # return self._training_docs
return self.dataset["train"]
def validation_docs(self): def validation_docs(self):
return map(self._process_doc, self.dataset["validation"]) return self.dataset["validation"]
def _process_doc(self, doc):
return {
"id": doc["query_id"],
"passage": doc["passage"],
"question": doc["question"],
"answers": self.get_answers(doc),
}
@classmethod
def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers):
""" Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}]
"""
vas = []
for i in range(len(validated_answers["number"])):
vas.append({
"number": validated_answers["number"][i],
"date": validated_answers["date"][i],
"spans": validated_answers["spans"][i],
})
return vas
answers = []
answers_set = set()
candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
for candidate in candidates:
answer = cls.parse_answer(candidate)
if answer in answers_set:
continue
answers_set.add(answer)
answers.append(answer)
return answers
@classmethod @classmethod
def parse_answer(cls, answer): def parse_answer(cls, answer):
...@@ -100,29 +67,33 @@ class DROP(Task): ...@@ -100,29 +67,33 @@ class DROP(Task):
return (str(answer["number"]),) return (str(answer["number"]),)
if answer["spans"] != []: if answer["spans"] != []:
return tuple(answer["spans"]) return tuple(answer["spans"])
return (" ".join([answer["date"]["day"], return (
answer["date"]["month"], " ".join(
answer["date"]["year"]]).strip(),) [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
def doc_to_text(self, doc): # def doc_to_text(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" # return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc): # def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"][0]) # return " " + ", ".join(doc["answers"][0])
def construct_requests(self, doc, ctx): # def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of # """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. # Requests which will be sent to the LM.
:param doc: # :param doc:
The document as returned from training_docs, validation_docs, or test_docs. # The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str # :param ctx: str
The context string, generated by fewshot_context. This includes the natural # The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question # language description, as well as the few shot examples, and the question
part of the document for `doc`. # part of the document for `doc`.
""" # """
conts = [rf.greedy_until(ctx, ["."])] # conts = [rf.greedy_until(ctx, ["."])]
return conts # return conts
def stopping_criteria(self):
return "."
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -134,7 +105,21 @@ class DROP(Task): ...@@ -134,7 +105,21 @@ class DROP(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
preds, golds = results, doc["answers"]
pred = results[0].strip()
target = self.doc_to_target(doc).strip()
print("*" * 80)
print(f"DOC: {doc}")
print(f"PS: {self.prompt.apply(doc)}")
print(f"TEXT: {self.doc_to_text(doc)}")
print(f"TARGET: {target} END TARGET")
print(f"PRED: {pred} END PRED")
print("*" * 80)
preds = [pred]
golds = [target]
max_em = 0 max_em = 0
max_f1 = 0 max_f1 = 0
for gold_answer in golds: for gold_answer in golds:
...@@ -142,10 +127,7 @@ class DROP(Task): ...@@ -142,10 +127,7 @@ class DROP(Task):
if gold_answer[0].strip(): if gold_answer[0].strip():
max_em = max(max_em, exact_match) max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score) max_f1 = max(max_f1, f1_score)
return { return {"em": max_em, "f1": max_f1}
"em": max_em,
"f1": max_f1
}
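# Illustrative sketch, not part of this change: the get_metrics helper below computes
# DROP's bag-of-words EM/F1 over normalized answer spans. The prompt handle and the
# answers are made up; an identical prediction and gold span should score 1.0 on both.
#
#   task = DROP(prompt=prompt)  # `prompt` is a promptsource template, as in get_task_dict_promptsource
#   task.get_metrics("10 yards", ("10 yards",))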
def get_metrics(self, predicted, gold): def get_metrics(self, predicted, gold):
""" """
...@@ -158,7 +140,9 @@ class DROP(Task): ...@@ -158,7 +140,9 @@ class DROP(Task):
predicted_bags = self._answer_to_bags(predicted) predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold) gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): if set(predicted_bags[0]) == set(gold_bags[0]) and len(
predicted_bags[0]
) == len(gold_bags[0]):
exact_match = 1.0 exact_match = 1.0
else: else:
exact_match = 0.0 exact_match = 0.0
...@@ -190,7 +174,9 @@ class DROP(Task): ...@@ -190,7 +174,9 @@ class DROP(Task):
for gold_index, gold_item in enumerate(gold): for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted): for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item): if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item) scores[gold_index, pred_index] = self._compute_f1(
pred_item, gold_item
)
row_ind, col_ind = linear_sum_assignment(-scores) row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))]) max_scores = np.zeros([max(len(gold), len(predicted))])
...@@ -256,7 +242,11 @@ class DROP(Task): ...@@ -256,7 +242,11 @@ class DROP(Task):
def _normalize(self, answer): def _normalize(self, answer):
tokens = [ tokens = [
self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower())))) self._white_space_fix(
self._remove_articles(
self._fix_number(self._remove_punc(token.lower()))
)
)
for token in self._tokenize(answer) for token in self._tokenize(answer)
] ]
tokens = [token for token in tokens if token.strip()] tokens = [token for token in tokens if token.strip()]
...@@ -269,10 +259,7 @@ class DROP(Task): ...@@ -269,10 +259,7 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"em": mean, "f1": mean}
"em": mean,
"f1": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -280,7 +267,4 @@ class DROP(Task): ...@@ -280,7 +267,4 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"em": True, "f1": True}
"em": True,
"f1": True
}
...@@ -14,7 +14,7 @@ respect to a wide range of linguistic phenomena found in natural language. ...@@ -14,7 +14,7 @@ respect to a wide range of linguistic phenomena found in natural language.
Homepage: https://gluebenchmark.com/ Homepage: https://gluebenchmark.com/
""" """
import numpy as np import numpy as np
from lm_eval.base import rf, Task from lm_eval.base import PromptSourceTask, rf, Task
from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno from lm_eval.metrics import mean, matthews_corrcoef, f1_score, yesno
from lm_eval.utils import general_detokenize from lm_eval.utils import general_detokenize
...@@ -45,7 +45,7 @@ _CITATION = """ ...@@ -45,7 +45,7 @@ _CITATION = """
# Single-Sentence Tasks # Single-Sentence Tasks
class CoLA(Task): class CoLA(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "cola" DATASET_NAME = "cola"
...@@ -67,37 +67,20 @@ class CoLA(Task): ...@@ -67,37 +67,20 @@ class CoLA(Task):
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc): # def process_results(self, doc, results):
return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"]) # answer_choices_list = self.prompt.get_answer_choices_list(doc)
# pred = np.argmax(results)
def doc_to_target(self, doc): # target = answer_choices_list.index(self.doc_to_target(doc).strip())
return " {}".format({1: "yes", 0: "no"}[doc["label"]]) # return {"mcc": (target, pred)}
def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, " yes")
ll_false, _ = rf.loglikelihood(ctx, " no")
return ll_true, ll_false
def process_results(self, doc, results):
ll_true, ll_false = results
pred = ll_true > ll_false
gold = doc["label"]
return {
"mcc": (gold, pred)
}
def higher_is_better(self): # def higher_is_better(self):
return { # return {"mcc": True}
"mcc": True
}
def aggregation(self): # def aggregation(self):
return { # return {"mcc": matthews_corrcoef}
"mcc": matthews_corrcoef
}
class SST(Task): class SST(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "sst2" DATASET_NAME = "sst2"
...@@ -119,42 +102,11 @@ class SST(Task): ...@@ -119,42 +102,11 @@ class SST(Task):
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
general_detokenize(doc["sentence"]),
)
def doc_to_target(self, doc):
return " {}".format({1: "positive", 0: "negative"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_positive, _ = rf.loglikelihood(ctx, " positive")
ll_negative, _ = rf.loglikelihood(ctx, " negative")
return ll_positive, ll_negative
def process_results(self, doc, results):
ll_positive, ll_negative = results
pred = ll_positive > ll_negative
gold = doc["label"]
return {
"acc": pred == gold
}
def higher_is_better(self):
return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
}
# Inference Tasks # Inference Tasks
class MNLI(Task): class MNLI(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "mnli" DATASET_NAME = "mnli"
...@@ -181,41 +133,6 @@ class MNLI(Task): ...@@ -181,41 +133,6 @@ class MNLI(Task):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["test_matched"] return self.dataset["test_matched"]
def doc_to_text(self, doc):
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
doc["premise"],
doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'),
)
def doc_to_target(self, doc):
# True = entailment
# False = contradiction
# Neither = neutral
return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, " True")
ll_neither, _ = rf.loglikelihood(ctx, " Neither")
ll_false, _ = rf.loglikelihood(ctx, " False")
return ll_true, ll_neither, ll_false
def process_results(self, doc, results):
gold = doc["label"]
pred = np.argmax(results)
return {
"acc": pred == gold
}
def higher_is_better(self):
return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
}
class MNLIMismatched(MNLI): class MNLIMismatched(MNLI):
VERSION = 0 VERSION = 0
...@@ -229,7 +146,7 @@ class MNLIMismatched(MNLI): ...@@ -229,7 +146,7 @@ class MNLIMismatched(MNLI):
return self.dataset["test_mismatched"] return self.dataset["test_mismatched"]
class QNLI(Task): class QNLI(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "qnli" DATASET_NAME = "qnli"
...@@ -251,42 +168,8 @@ class QNLI(Task): ...@@ -251,42 +168,8 @@ class QNLI(Task):
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
doc["question"],
doc["sentence"],
)
def doc_to_target(self, doc):
# True = entailment
# False = not entailment
return " {}".format({0: "yes", 1: "no"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no
def process_results(self, doc, results):
ll_yes, ll_no = results
pred = ll_no > ll_yes
gold = doc["label"]
return {
"acc": pred == gold
}
def higher_is_better(self):
return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
}
class WNLI(PromptSourceTask):
class WNLI(Task):
VERSION = 1 VERSION = 1
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "wnli" DATASET_NAME = "wnli"
...@@ -301,49 +184,13 @@ class WNLI(Task): ...@@ -301,49 +184,13 @@ class WNLI(Task):
return False return False
def training_docs(self): def training_docs(self):
if self._training_docs is None: return self.dataset["train"]
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: {} True or False?\nAnswer:".format(
doc["sentence1"],
doc["sentence2"],
)
def doc_to_target(self, doc):
# True = entailment
# False = not_entailment
return " {}".format({0: "False", 1: "True"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, " True")
ll_false, _ = rf.loglikelihood(ctx, " False")
return ll_true, ll_false
def process_results(self, doc, results):
ll_true, ll_false = results
pred = ll_true > ll_false
gold = doc["label"]
return {
"acc": pred == gold
}
def higher_is_better(self):
return {
"acc": True
}
def aggregation(self):
return {
"acc": mean
}
class RTE(Task): class RTE(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "rte" DATASET_NAME = "rte"
...@@ -365,45 +212,17 @@ class RTE(Task): ...@@ -365,45 +212,17 @@ class RTE(Task):
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc):
return "{}\nQuestion: {} True or False?\nAnswer:".format(
doc["sentence1"],
doc["sentence2"],
)
def doc_to_target(self, doc):
# 0 = entailment
# 1 = not_entailment
return " {}".format({0: "True", 1: "False"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, " True")
ll_false, _ = rf.loglikelihood(ctx, " False")
return ll_true, ll_false
def process_results(self, doc, results):
ll_true, ll_false = results
pred = ll_false > ll_true
gold = doc["label"]
return {
"acc": pred == gold
}
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
def aggregation(self): def aggregation(self):
return { return {"acc": mean}
"acc": mean
}
# Similarity and Paraphrase Tasks # Similarity and Paraphrase Tasks
class MRPC(Task): class MRPC(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "mrpc" DATASET_NAME = "mrpc"
...@@ -417,6 +236,9 @@ class MRPC(Task): ...@@ -417,6 +236,9 @@ class MRPC(Task):
def has_test_docs(self): def has_test_docs(self):
return False return False
def stopping_criteria(self):
return "\n"
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(self.dataset["train"])
...@@ -425,43 +247,8 @@ class MRPC(Task): ...@@ -425,43 +247,8 @@ class MRPC(Task):
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc):
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
general_detokenize(doc["sentence1"]),
general_detokenize(doc["sentence2"]),
)
def doc_to_target(self, doc):
return " {}".format(yesno(doc["label"]))
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no
def process_results(self, doc, results):
ll_yes, ll_no = results
gold = doc["label"]
pred = ll_yes > ll_no
return {
"acc": pred == gold,
"f1": (gold, pred),
}
def higher_is_better(self): class QQP(PromptSourceTask):
return {
"acc": True,
"f1": True
}
def aggregation(self):
return {
"acc": mean,
"f1": f1_score
}
class QQP(Task):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "qqp" DATASET_NAME = "qqp"
...@@ -483,41 +270,6 @@ class QQP(Task): ...@@ -483,41 +270,6 @@ class QQP(Task):
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc):
return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format(
doc["question1"],
doc["question2"],
)
def doc_to_target(self, doc):
return " {}".format(yesno(doc["label"]))
def construct_requests(self, doc, ctx):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
return ll_yes, ll_no
def process_results(self, doc, results):
ll_yes, ll_no = results
gold = doc["label"]
pred = ll_yes > ll_no
return {
"acc": pred == gold,
"f1": (gold, pred),
}
def higher_is_better(self):
return {
"acc": True,
"f1": True
}
def aggregation(self):
return {
"acc": mean,
"f1": f1_score
}
class STSB(Task): class STSB(Task):
VERSION = 0 VERSION = 0
...@@ -554,22 +306,22 @@ class STSB(Task): ...@@ -554,22 +306,22 @@ class STSB(Task):
return " {}".format(doc["label"]) return " {}".format(doc["label"])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str :param ctx: str
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -578,22 +330,22 @@ class STSB(Task): ...@@ -578,22 +330,22 @@ class STSB(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def aggregation(self): def aggregation(self):
""" """
:returns: {str: [float] -> float} :returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self): def higher_is_better(self):
""" """
:returns: {str: bool} :returns: {str: bool}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
"""
Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference
https://arxiv.org/abs/1902.01007
A controlled evaluation set called HANS (Heuristic Analysis for NLI Systems),
which contains many examples where the heuristics fail.
Homepage: https://github.com/tommccoy1/hans
"""
from lm_eval.base import PromptSourceTask
_CITATION = """
@inproceedings{mccoy-etal-2019-right,
title = "Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference",
author = "McCoy, Tom and
Pavlick, Ellie and
Linzen, Tal",
booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2019",
address = "Florence, Italy",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/P19-1334",
doi = "10.18653/v1/P19-1334",
pages = "3428--3448",
abstract = "A machine learning system can score well on a given test set by relying on heuristics that are effective for frequent example types but break down in more challenging cases. We study this issue within natural language inference (NLI), the task of determining whether one sentence entails another. We hypothesize that statistical NLI models may adopt three fallible syntactic heuristics: the lexical overlap heuristic, the subsequence heuristic, and the constituent heuristic. To determine whether models have adopted these heuristics, we introduce a controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), which contains many examples where the heuristics fail. We find that models trained on MNLI, including BERT, a state-of-the-art model, perform very poorly on HANS, suggesting that they have indeed adopted these heuristics. We conclude that there is substantial room for improvement in NLI systems, and that the HANS dataset can motivate and measure progress in this area.",
}
"""
class HANS(PromptSourceTask):
VERSION = 0
DATASET_PATH = "hans"
DATASET_NAME = None
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
if self.has_training_docs():
# We cache training documents in `self._training_docs` for faster
# few-shot processing. If the data is too large to fit in memory,
# return the training data as a generator instead of a list.
if self._training_docs is None:
self._training_docs = list(self.dataset["train"])
return self._training_docs
def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]
def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]
@@ -12,7 +12,7 @@ Homepage: https://www.cs.cmu.edu/~glai1/data/race/
import collections
import datasets
import numpy as np
-from lm_eval.base import rf, Task
+from lm_eval.base import PromptSourceTask, rf
from lm_eval.metrics import mean

@@ -34,13 +34,13 @@ class each:
        return list(map(self.f, other))

-class RACE(Task):
+class RACE(PromptSourceTask):
    VERSION = 1
    DATASET_PATH = "race"
    DATASET_NAME = "high"

    cache = {}
    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

    def has_training_docs(self):
        return True
@@ -51,83 +51,92 @@ class RACE(Task):
    def has_test_docs(self):
        return True

    # def _collate_data(self, set):
    #     if set in self.cache:
    #         return self.cache[set]
    #     # One big issue with HF's implementation of this dataset: it makes a
    #     # separate document for each question; meanwhile, in the GPT3 paper it
    #     # is shown that one document is made per passage.

    #     r = collections.defaultdict(list)
    #     for item in datasets.load_dataset(
    #         path=self.DATASET_PATH, name=self.DATASET_NAME
    #     )[set]:
    #         r[item["article"]].append(item)

    #     res = list(
    #         r.values()
    #         >> each(
    #             lambda x: {
    #                 "article": x[0]["article"],
    #                 "problems": x
    #                 >> each(
    #                     lambda y: {
    #                         "question": y["question"],
    #                         "answer": y["answer"],
    #                         "options": y["options"],
    #                     }
    #                 ),
    #             }
    #         )
    #     )

    #     self.cache[set] = res
    #     return res

    def training_docs(self):
-        return self._collate_data("train")
+        return self.dataset["train"]

    def validation_docs(self):
-        return self._collate_data("validation")
+        return self.dataset["validation"]

    def test_docs(self):
-        return self._collate_data("test")
+        return self.dataset["test"]

    @classmethod
    def get_answer_option(cls, problem):
        answer = cls.letter_to_num[problem["answer"]]
        return problem["options"][answer]

    @classmethod
    def last_problem(cls, doc):
        return doc["problems"][-1]

    # def doc_to_text(self, doc):
    #     text = 'Article: ' + doc['article'] + '\n\n'
    #     for problem in doc['problems'][:-1]:
    #         if problem['question'][-6:] == ' _ .':
    #             text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
    #         else:
    #             question = 'Question: ' + problem['question'] + '\n'
    #             answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
    #             text += question + answer
    #     text += self.last_problem(doc)['question']
    #     return text

    # def doc_to_target(self, doc):
    #     return " " + self.get_answer_option(self.last_problem(doc))

    # def construct_requests(self, doc, ctx):
    #     """Uses RequestFactory to construct Requests and returns an iterable of
    #     Requests which will be sent to the LM.
    #
    #     :param doc:
    #         The document as returned from training_docs, validation_docs, or test_docs.
    #     :param ctx: str
    #         The context string, generated by fewshot_context. This includes the natural
    #         language description, as well as the few shot examples, and the question
    #         part of the document for `doc`.
    #     """
    #     problem = self.last_problem(doc)
    #     ll_choices = [
    #         rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
    #     ]
    #     return ll_choices

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
@@ -135,28 +144,24 @@ class RACE(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        #
        gold = self.letter_to_num[self.doc_to_target(doc)]
        # gold = self.letter_to_num[self.last_problem(doc)["answer"]]
        pred = np.argmax(results)
        return {"acc": int(pred == gold)}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
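As a quick illustration of how these hooks fit together (a sketch, not part of the diff): process_results emits one per-document dict, and the aggregation function declared above reduces the list of per-document values.

# Minimal sketch, not from the diff: per-document "acc" values are reduced with
# the `mean` aggregator registered above.
from lm_eval.metrics import mean

per_doc_results = [{"acc": 1}, {"acc": 0}, {"acc": 1}]
accuracy = mean([r["acc"] for r in per_doc_results])  # 2/3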
@@ -12,7 +12,7 @@ TODO: WSC requires free-form generation.
import numpy as np
import sklearn
import transformers.data.metrics.squad_metrics as squad_metrics
-from lm_eval.base import rf, Task
+from lm_eval.base import rf, PromptSourceTask
from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno
from lm_eval.utils import general_detokenize

@@ -32,7 +32,7 @@ _CITATION = """
"""

-class BoolQ(Task):
+class BoolQ(PromptSourceTask):
    VERSION = 1
    DATASET_PATH = "super_glue"
    DATASET_NAME = "boolq"

@@ -54,41 +54,8 @@ class BoolQ(Task):
    def validation_docs(self):
        return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
-
-    def doc_to_target(self, doc):
-        return " " + yesno(doc['label'])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-        return {
-            "acc": acc
-        }
-
-    def higher_is_better(self):
-        return {
-            "acc": True
-        }
-
-    def aggregation(self):
-        return {
-            "acc": mean
-        }

-class CommitmentBank(Task):
+class CommitmentBank(PromptSourceTask):
    VERSION = 1
    DATASET_PATH = "super_glue"
    DATASET_NAME = "cb"

@@ -110,40 +77,15 @@ class CommitmentBank(Task):
    def validation_docs(self):
        return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
-            doc["premise"],
-            doc["hypothesis"],
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, ' True')
-        ll_false, _ = rf.loglikelihood(ctx, ' False')
-        ll_neither, _ = rf.loglikelihood(ctx, ' Neither')
-        return ll_true, ll_false, ll_neither
-
    def process_results(self, doc, results):
        gold = doc["label"]
        pred = np.argmax(results)
        acc = 1.0 if pred == gold else 0.0
        return {"acc": acc, "f1": (pred, gold)}

    def higher_is_better(self):
        return {"acc": True, "f1": True}

    @classmethod
    def cb_multi_fi(cls, items):
@@ -155,7 +97,7 @@ class CommitmentBank(Task):
        f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2)
        avg_f1 = mean([f11, f12, f13])
        return avg_f1

    def aggregation(self):
        return {
            "acc": mean,
@@ -163,7 +105,7 @@ class CommitmentBank(Task):
        }
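The "f1" entry above is aggregated with cb_multi_fi, an unweighted average of the per-class binary F1 scores over the three CB labels. A minimal sketch of that computation (the preds/golds arrays are made up for illustration, not from the diff):

# Minimal sketch, not from the diff: macro-averaged F1 over the three CB labels,
# mirroring cb_multi_fi above.
import numpy as np
import sklearn.metrics

preds = np.array([0, 1, 2, 0, 2])
golds = np.array([0, 1, 2, 1, 2])
f1_per_class = [
    sklearn.metrics.f1_score(y_true=golds == label, y_pred=preds == label)
    for label in (0, 1, 2)
]
avg_f1 = sum(f1_per_class) / len(f1_per_class)  # ~0.78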
-class Copa(Task):
+class Copa(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "copa"

@@ -185,53 +127,25 @@ class Copa(Task):
    def validation_docs(self):
        return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        # Drop the period
-        connector = {
-            "cause": "because",
-            "effect": "therefore",
-        }[doc["question"]]
-        return doc["premise"].strip()[:-1] + f" {connector}"
-
-    def doc_to_target(self, doc):
-        correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
-        # Connect the sentences
-        return " " + self.convert_choice(correct_choice)
-
-    def construct_requests(self, doc, ctx):
-        choice1 = " " + self.convert_choice(doc["choice1"])
-        choice2 = " " + self.convert_choice(doc["choice2"])
-        ll_choice1, _ = rf.loglikelihood(ctx, choice1)
-        ll_choice2, _ = rf.loglikelihood(ctx, choice2)
-        return ll_choice1, ll_choice2
-
    def process_results(self, doc, results):
        gold = doc["label"]
        pred = np.argmax(results)
        acc = 1.0 if pred == gold else 0.0
        return {"acc": acc}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    @staticmethod
    def convert_choice(choice):
        return choice[0].lower() + choice[1:]

-class MultiRC(Task):
+class MultiRC(PromptSourceTask):
    VERSION = 1
    DATASET_PATH = "super_glue"
    DATASET_NAME = "multirc"

@@ -253,45 +167,19 @@ class MultiRC(Task):
    def validation_docs(self):
        return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
-
-    def doc_to_target(self, doc):
-        return " " + self.format_answer(answer=doc["answer"], label=doc["label"])
-
-    @staticmethod
-    def format_answer(answer, label):
-        label_str = "yes" if label else "no"
-        return f"{answer}\nIs the answer correct? {label_str}"
-
-    def construct_requests(self, doc, ctx):
-        true_choice = self.format_answer(answer=doc["answer"], label=True)
-        false_choice = self.format_answer(answer=doc["answer"], label=False)
-        ll_true_choice, _ = rf.loglikelihood(ctx, f' {true_choice}')
-        ll_false_choice, _ = rf.loglikelihood(ctx, f' {false_choice}')
-        return ll_true_choice, ll_false_choice
-
    def process_results(self, doc, results):
        ll_true_choice, ll_false_choice = results
        pred = ll_true_choice > ll_false_choice
        return {"acc": (pred, doc)}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": acc_all}

-class ReCoRD(Task):
+class ReCoRD(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "record"

@@ -311,56 +199,31 @@ class ReCoRD(Task):
        if self._training_docs is None:
            self._training_docs = []
            for doc in self.dataset["train"]:
-                self._training_docs.append(self._process_doc(doc))
+                self._training_docs.append(doc)
        return self._training_docs

    def validation_docs(self):
        # See: training_docs
        for doc in self.dataset["validation"]:
-            yield self._process_doc(doc)
+            yield doc

-    @classmethod
-    def _process_doc(cls, doc):
-        return {
-            "passage": doc["passage"],
-            "query": doc["query"],
-            "entities": sorted(list(set(doc["entities"]))),
-            "answers": sorted(list(set(doc["answers"]))),
-        }
-
-    def doc_to_text(self, doc):
-        initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
-        text = initial_text + "\n\n"
-        for highlight in highlights:
-            text += f" - {highlight}.\n"
-        return text
-
-    @classmethod
-    def format_answer(cls, query, entity):
-        return f' - {query}'.replace("@placeholder", entity)
-
-    def doc_to_target(self, doc):
-        # We only output the first correct entity in a doc
-        return self.format_answer(query=doc["query"], entity=doc["answers"][0])
-
-    def construct_requests(self, doc, ctx):
-        requests = [
-            rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity))
-            for entity in doc["entities"]
-        ]
-        return requests
-
    def process_results(self, doc, results):
        # ReCoRD's evaluation is actually deceptively simple:
        # - Pick the maximum likelihood prediction entity
        # - Evaluate the accuracy and token F1 PER EXAMPLE
        # - Average over all examples
+        # TODO (jon-tow): Look at result
        max_idx = np.argmax(np.array([result[0] for result in results]))
        prediction = doc["entities"][max_idx]
        gold_label_set = doc["answers"]
        f1 = metric_max_over_ground_truths(
            squad_metrics.compute_f1, prediction, gold_label_set
        )
        em = metric_max_over_ground_truths(
            squad_metrics.compute_exact, prediction, gold_label_set
        )

        return {
            "f1": f1,
@@ -380,7 +243,7 @@ class ReCoRD(Task):
        }
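A minimal sketch of the per-example ReCoRD scoring above (example strings assumed, not from the diff): the prediction is scored against every gold answer and the best score is kept.

# Minimal sketch, not from the diff: max-over-gold-answers token F1 / exact match,
# using the same SQuAD metric helpers imported at the top of this file.
from transformers.data.metrics.squad_metrics import compute_exact, compute_f1

prediction = "Barack Obama"
gold_answers = ["Obama", "Barack Obama"]  # made-up example
f1 = max(compute_f1(gold, prediction) for gold in gold_answers)    # 1.0
em = max(compute_exact(gold, prediction) for gold in gold_answers)  # 1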
-class WordsInContext(Task):
+class WordsInContext(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "wic"

@@ -402,50 +265,19 @@ class WordsInContext(Task):
    def validation_docs(self):
        return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
-               " two sentences above?\nAnswer:".format(
-            doc["sentence1"],
-            doc["sentence2"],
-            doc["sentence1"][doc["start1"]:doc["end1"]],
-        )
-
-    def doc_to_target(self, doc):
-        return " {}".format({0: "no", 1: "yes"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-        return {
-            "acc": acc
-        }
-
    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

-class SGWinogradSchemaChallenge(Task):
+class SGWinogradSchemaChallenge(PromptSourceTask):
    VERSION = 0
    # Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
    # binary version of the task.
    DATASET_PATH = "super_glue"
-    DATASET_NAME = "wsc"
+    DATASET_NAME = "wsc.fixed"

    def has_training_docs(self):
        return True

@@ -461,56 +293,15 @@ class SGWinogradSchemaChallenge(Task):
        if self._training_docs is None:
            # GPT-3 Paper's format only uses positive examples for fewshot "training"
            self._training_docs = [
                doc for doc in self.dataset["train"] if doc["label"]
            ]
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        raw_passage = doc["text"]
-        # NOTE: HuggingFace span indices are word-based not character-based.
-        pre = " ".join(raw_passage.split()[:doc["span2_index"]])
-        post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:]
-        passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post)
-        noun = doc["span1_text"]
-        pronoun = doc["span2_text"]
-        text = (
-            f"Passage: {passage}\n"
-            + f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n"
-            + "Answer:"
-        )
-        return text
-
-    def doc_to_target(self, doc):
-        return " " + yesno(doc['label'])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-        return {
-            "acc": acc
-        }
-
    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
@@ -146,6 +146,19 @@ class Reorderer:
        return res

+def flatten(d, parent_key='', sep='_'):
+    # From: https://stackoverflow.com/a/6027615
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, collections.MutableMapping):
+            items.extend(flatten(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
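A minimal usage sketch (example dict assumed, not from the diff); note that the abstract base class is collections.abc.MutableMapping on Python 3.10+, where the plain collections alias used above no longer exists:

# Minimal sketch, not from the diff: nested metric dicts are collapsed into a
# single level with underscore-joined keys.
nested = {"rouge": {"precision": 0.5, "recall": 0.25}, "acc": 1.0}
flat = flatten(nested)
# {"rouge_precision": 0.5, "rouge_recall": 0.25, "acc": 1.0}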
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
@@ -30,7 +30,7 @@ def main():
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
-    task_dict = tasks.get_task_dict(task_names)
+    task_dict = tasks.get_task_dict_promptsource(task_names)

    description_dict = {}
    if args.description_dict_path:
@@ -18,8 +18,12 @@ setuptools.setup(
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.6",
    install_requires=[
+        "promptsource @ git+https://github.com/bigscience-workshop/promptsource@eval-hackathon",
+        "wrapt",
+        "nltk",
+        "jinja2",
        "black",
        "datasets==2.0.0",
        "click>=7.1",
@@ -42,9 +46,9 @@ setuptools.setup(
        "openai==0.6.4",
        "jieba==0.42.1",
        "nagisa==0.2.7",
        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
    ],
    dependency_links=[
        "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
    ],
)
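A brief usage note: with the dependency added above, installing the package (for example an editable install such as pip install -e ., command assumed and not part of this diff) also pulls in the pinned eval-hackathon branch of promptsource, which supplies the prompt templates that the PromptSourceTask subclasses in this merge rely on.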