Commit 88745155 authored by cjlovering
Browse files

Initial integration

parent 6caa0afd
import abc
from typing import Iterable
import numpy as np
import random
import re
......@@ -118,7 +119,6 @@ class LM(abc.ABC):
class BaseLM(LM):
@property
@abstractmethod
def eot_token_id(self):
......@@ -145,13 +145,16 @@ class BaseLM(LM):
pass
@abstractmethod
def tok_encode(self, string: str): pass
def tok_encode(self, string: str):
pass
@abstractmethod
def tok_decode(self, tokens: Iterable[int]): pass
def tok_decode(self, tokens: Iterable[int]):
pass
@abstractmethod
def _model_generate(self, context, max_length, eos_token_id): pass
def _model_generate(self, context, max_length, eos_token_id):
pass
@abstractmethod
def _model_call(self, inps):
......@@ -187,19 +190,26 @@ class BaseLM(LM):
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
for string, in tqdm(requests):
rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
for (string,) in tqdm(requests):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
)))
),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
......@@ -226,7 +236,9 @@ class BaseLM(LM):
# TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
for chunk in utils.chunks(
tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
):
inps = []
cont_toks_list = []
inplens = []
......@@ -252,44 +264,60 @@ class BaseLM(LM):
# when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1],
dtype=torch.long
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
).to(self.device)
inplen, = inp.shape
(inplen,) = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen
padding_length = (
padding_length if padding_length is not None else inplen
)
# pad length from seq to padding_length
inp = torch.cat([
inp = torch.cat(
[
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
], dim=0)
torch.zeros(padding_length - inplen, dtype=torch.long).to(
inp.device
), # [padding_length - seq]
],
dim=0,
)
inps.append(inp.unsqueeze(0)) # [1, padding_length]
cont_toks_list.append(cont)
inplens.append(inplen)
batched_inps = torch.cat(inps, dim=0) # [batch, padding_length
multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu() # [batch, padding_length, vocab]
multi_logits = F.log_softmax(
self._model_call(batched_inps), dim=-1
).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks \
in zip(chunk, multi_logits, inps, inplens, cont_toks_list):
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
chunk, multi_logits, inps, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]
logits = logits[inplen - contlen : inplen].unsqueeze(
0
) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
......@@ -319,13 +347,17 @@ class BaseLM(LM):
if isinstance(until, str):
until = [until]
primary_until, = self.tok_encode(until[0])
(primary_until,) = self.tok_encode(until[0])
context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)
cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
for term in until:
s = s.split(term)[0]
......@@ -383,7 +415,7 @@ class Task(abc.ABC):
self._fewshot_docs = None
def download(self, data_dir=None, cache_dir=None, download_mode=None):
""" Downloads and returns the task dataset.
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
:param data_dir: str
......@@ -412,7 +444,7 @@ class Task(abc.ABC):
name=self.DATASET_NAME,
data_dir=data_dir,
cache_dir=cache_dir,
download_mode=download_mode
download_mode=download_mode,
)
@abstractmethod
......@@ -478,7 +510,7 @@ class Task(abc.ABC):
@abstractmethod
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -523,15 +555,19 @@ class Task(abc.ABC):
def fewshot_description(self):
    """Deprecated hook that used to supply a task-level description.

    Emits a ``DeprecationWarning`` telling callers to pass custom
    descriptions to the ``evaluate`` function instead.

    :returns: str
        Always the empty string.
    """
    import warnings

    warnings.warn(
        "`fewshot_description` will be removed in future versions. Pass "
        "any custom descriptions to the `evaluate` function instead.",
        DeprecationWarning,
    )
    return ""
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
""" Returns a fewshot context string that is made up of a prepended description
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
......@@ -548,7 +584,9 @@ class Task(abc.ABC):
:returns: str
The fewshot context.
"""
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
......@@ -556,7 +594,9 @@ class Task(abc.ABC):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
......@@ -569,7 +609,9 @@ class Task(abc.ABC):
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs() if self.has_validation_docs() else self.test_docs()
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
......@@ -577,23 +619,90 @@ class Task(abc.ABC):
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = "\n\n".join(
[self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
) + "\n\n"
labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
return description + labeled_examples + example
class MultipleChoiceTask(Task):
class PromptSourceTask(Task):
    """Task whose prompt text and target are produced by a prompt template.

    The template object is stored on ``self.prompt``. This class reads
    ``prompt.apply(doc)`` (unpacked as ``(text, target)``),
    ``prompt.metadata.choices_in_prompt`` and
    ``prompt.get_fixed_answer_choices_list()``.
    NOTE(review): these are presumably promptsource ``Template`` members --
    confirm against the promptsource version in use.
    """

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
        super().__init__(data_dir, cache_dir, download_mode)
        self.prompt = prompt

    def doc_to_target(self, doc):
        """Return the template's target for ``doc``, prefixed with one space."""
        # Must go through `self.prompt`; there is no module-level `prompt`.
        _, target = self.prompt.apply(doc)
        return f" {target}"

    def doc_to_text(self, doc):
        """Return the template's prompt text for ``doc``."""
        text, _ = self.prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns: list
            One loglikelihood request per fixed answer choice when the choices
            are part of the prompt, otherwise a single greedy_until request.
        """
        if self.prompt.metadata.choices_in_prompt:
            # Index 0 selects the log-likelihood component of each response.
            return [
                rf.loglikelihood(ctx, f" {answer_choice}")[0]
                for answer_choice in self.prompt.get_fixed_answer_choices_list()
            ]

        # TODO(Albert): What is the stop symbol? Is it model specific?
        # A greedy_until request yields a single value, so it must not be
        # tuple-unpacked: iterating such a Request raises IndexError.
        return [rf.greedy_until(ctx, ["\nQ:"])]

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        Not implemented yet; always raises.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :raises NotImplementedError: always.
        """
        raise NotImplementedError(
            "Implement process results using the `prompt.metadata.metrics`. See below."
        )
        # Implementation sketch (kept as comments; it was unreachable code):
        #
        # if self.prompt.metadata.choices_in_prompt:
        #     for answer_choice, result in zip(
        #         self.prompt.get_fixed_answer_choices_list(), results
        #     ):
        #         ...
        # else:
        #     continuation = results
        #
        # Map metric name to HF metric.
        # TODO(Albert): What is Other?
        # metric_names = self.prompt.metadata.metrics
class MultipleChoiceTask(Task):
def doc_to_target(self, doc):
    """Return the gold answer choice of ``doc`` prefixed with a single space.

    :param doc:
        Mapping with a ``"choices"`` sequence of strings and a ``"gold"``
        index into that sequence.
    :returns: str
    """
    return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc, ctx):
    """Build one loglikelihood request per answer choice of ``doc``.

    :param doc:
        Mapping with a ``"choices"`` sequence.
    :param ctx: str
        The context string each choice is scored against.
    :returns: list
        Requests in the same order as ``doc["choices"]``.
    """
    # Each choice is scored as a continuation with a leading space; index 0
    # selects the first component of the loglikelihood response.
    lls = [
        rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
    ]
    return lls
......@@ -601,9 +710,9 @@ class MultipleChoiceTask(Task):
def process_results(self, doc, results):
gold = doc["gold"]
acc = 1. if np.argmax(results) == gold else 0.
acc = 1.0 if np.argmax(results) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
return {
"acc": acc,
......@@ -624,7 +733,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
def has_training_docs(self):
return False
......@@ -632,9 +740,15 @@ class PerplexityTask(Task, abc.ABC):
assert k == 0
return []
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "The number of fewshot examples must be 0 for perplexity tasks."
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`."
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
assert (
num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks."
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`."
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
......@@ -642,7 +756,9 @@ class PerplexityTask(Task, abc.ABC):
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
return ""
......@@ -665,7 +781,7 @@ class PerplexityTask(Task, abc.ABC):
return req
def process_results(self, doc, results):
loglikelihood, = results
(loglikelihood,) = results
words = self.count_words(doc)
bytes_ = self.count_bytes(doc)
return {
......@@ -687,13 +803,13 @@ class PerplexityTask(Task, abc.ABC):
@classmethod
def count_words(cls, doc):
""" Downstream tasks with custom word boundaries should override this! """
"""Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
def hash_args(attr, args):
    """Return a SHA-256 hex digest identifying a request.

    :param attr:
        JSON-serializable value placed first in the hashed list.
    :param args:
        Iterable of JSON-serializable values appended after ``attr``.
    :returns: str
        Hex digest of the UTF-8 encoded JSON list ``[attr, *args]``.
    """
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
class CacheHook:
......@@ -764,6 +880,7 @@ class CachingLM:
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
......@@ -771,16 +888,18 @@ class CachingLM:
# Number of indexable components in the response for each request type.
# None marks a request type whose Request cannot be iterated or indexed.
REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}
class Request:
def __init__(self, request_type, args, index=None):
if request_type not in REQUEST_RETURN_LENGTHS.keys():
raise NotImplementedError('The request type {} is not implemented!'.format(request_type))
raise NotImplementedError(
"The request type {} is not implemented!".format(request_type)
)
self.request_type = request_type
self.args = args
......@@ -788,17 +907,21 @@ class Request:
def __iter__(self):
    """Yield one indexed copy of this request per returned component.

    Because this is a generator, the IndexError below is raised when
    iteration starts, not when ``iter()`` is called.

    :raises IndexError: if the request type has no fixed return length.
    """
    return_length = REQUEST_RETURN_LENGTHS[self.request_type]
    if return_length is None:
        raise IndexError("This request type does not return multiple arguments!")
    for i in range(return_length):
        yield Request(self.request_type, self.args, i)
def __getitem__(self, i):
    """Return a copy of this request that selects component ``i`` of the response.

    ``i`` is not range-checked here.

    :raises IndexError: if the request type has no fixed return length.
    """
    if REQUEST_RETURN_LENGTHS[self.request_type] is None:
        raise IndexError("This request type does not return multiple arguments!")
    return Request(self.request_type, self.args, i)
def __eq__(self, other):
return self.request_type == other.request_type and self.args == other.args and self.index == other.index
return (
self.request_type == other.request_type
and self.args == other.args
and self.index == other.index
)
def __repr__(self):
return f"Req_{self.request_type}{self.args}[{self.index}]\n"
......@@ -808,6 +931,7 @@ class RequestFactory:
def __getattr__(self, attr):
    """Return a constructor for requests of type ``attr``.

    The returned callable packs its positional arguments into a tuple and
    passes them, with ``attr`` as the request type, to ``Request``.
    """

    def build_request(*request_args):
        return Request(attr, request_args)

    return build_request
......
......@@ -6,15 +6,27 @@ import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np
from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None, check_integrity=False):
def simple_evaluate(
model,
model_args=None,
tasks=[],
num_fewshot=0,
batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
......@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert tasks != [], "No tasks specified"
if isinstance(model, str):
if model_args is None: model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
'batch_size': batch_size, 'device': device
})
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.base.LM)
lm = model
if not no_cache:
lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
lm,
"lm_cache/"
+ model
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks)
task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)
if check_integrity:
run_task_tests(task_list=tasks)
......@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict=task_dict,
num_fewshot=num_fewshot,
limit=limit,
description_dict=description_dict
description_dict=description_dict,
)
# add info about the model and few shot config
......@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict
"description_dict": description_dict,
}
return results
@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param lm: obj
......@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert not provide_description # not implemented.
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
task_dict_items = [
(name, task)
for name, task in task_dict.items()
if(task.has_validation_docs() or task.has_test_docs())
if (task.has_validation_docs() or task.has_test_docs())
]
results = collections.defaultdict(dict)
......@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd.seed(42)
rnd.shuffle(task_docs)
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
......@@ -189,7 +218,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
......@@ -208,24 +239,28 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
task_name, prompt_name = task_name.split("+")
results[task_name]["task_name"] = task_name
results[task_name]["prompt_name"] = prompt_name
# aggregate results
for (task_name, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
return {
"results": dict(results),
"versions": dict(versions)
}
return {"results": dict(results), "versions": dict(versions)}
def make_table(result_dict):
......@@ -247,9 +282,9 @@ def make_table(result_dict):
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
else:
values.append([k, version, m, '%.4f' % v, '', ''])
values.append([k, version, m, "%.4f" % v, "", ""])
k = ""
version = ""
md_writer.value_matrix = values
......
from promptsource.templates import DatasetTemplates
from pprint import pprint
from typing import List, Union
......@@ -58,8 +60,8 @@ from . import storycloze
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
"wmt14": ["en-fr", "fr-en"], # French
"wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
......@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
"iwslt17": ["en-ar", "ar-en"], # Arabic
}
# 319 total
......@@ -91,7 +93,7 @@ TASK_REGISTRY = {
"rte": glue.RTE,
"qnli": glue.QNLI,
"qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet
# "stsb": glue.STSB, # not implemented yet
"sst": glue.SST,
"wnli": glue.WNLI,
# SuperGLUE
......@@ -102,34 +104,26 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD,
"wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"drop": drop.DROP,
"lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada
**lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA,
"prost": prost.PROST,
"mc_taco": mc_taco.MCTACO,
# Science related
"pubmedqa" : pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ,
"pubmedqa": pubmedqa.Pubmed_QA,
"sciq": sciq.SciQ,
"qasper": qasper.QASPER,
"qa4mre_2011" : qa4mre.QA4MRE_2011,
"qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2013": qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
......@@ -150,21 +144,17 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue
"mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus,
# math
"math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
......@@ -175,7 +165,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
......@@ -189,22 +178,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords,
# Pile
"pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3,
......@@ -228,7 +213,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
......@@ -297,7 +281,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
......@@ -323,17 +306,41 @@ def get_task_name_from_object(task_object):
return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
    """Build a name -> task mapping from a mixed list of names and task objects.

    String entries are looked up with ``get_task`` and instantiated with no
    arguments. Non-string entries are used as given and keyed by
    ``get_task_name_from_object``.

    :param task_name_list:
        Task names and/or task objects.
    :returns: dict
        String-named tasks first, then the object entries.
    """
    task_name_dict = {
        task_name: get_task(task_name)()
        for task_name in task_name_list
        if isinstance(task_name, str)
    }
    task_name_from_object_dict = {
        get_task_name_from_object(task_object): task_object
        for task_object in task_name_list
        if not isinstance(task_object, str)
    }
    # A name may come from a string entry or an object entry, never both;
    # otherwise the merge below would silently drop one of them.
    assert task_name_dict.keys().isdisjoint(task_name_from_object_dict)
    return {**task_name_dict, **task_name_from_object_dict}
def get_task_dict_promptsource(task_name_list: List[str]):
    """Loads a task instance for each prompt written for that task.

    Keys have the form ``"<task_name>+<prompt_name>"``; each value is the
    task class from ``get_task`` instantiated with ``prompt=<template>``.
    """
    tasks_by_key = {}
    for name in task_name_list:
        assert isinstance(name, str)
        templates = DatasetTemplates(name)
        for template_name in templates.all_template_names:
            template = templates[template_name]
            # NOTE: We choose a sep that can be easily split.
            key = f"{name}+{template_name}"
            tasks_by_key[key] = get_task(name)(prompt=template)
    return tasks_by_key
......@@ -51,44 +51,22 @@ class CoQA(Task):
def test_docs(self):
pass
def doc_to_text(self, doc):
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
# and a question qi, the task is to predict the answer ai
doc_text = doc["story"] + '\n\n'
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai
question = f"Q: {q}\n\n"
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer
return doc_text
@classmethod
def get_answers(cls, doc, turn_id):
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
answers = []
answer_forturn = doc["answers"]["input_text"][turn_id - 1]
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
@classmethod
def get_answer_choice(self, raw_text):
# Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return '0'
if squad_metrics.normalize_answer(raw_text) == "yes":
return '1'
if squad_metrics.normalize_answer(raw_text) == "no":
return '2'
return '3' # Not a yes/no question
# @classmethod
# def get_answers(cls, doc, turn_id):
# # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
# answers = []
# answer_forturn = doc["answers"]["input_text"][turn_id - 1]
# answers.append(answer_forturn)
# additional_answers = doc.get("additional_answers")
# if additional_answers:
# for key in additional_answers:
# additional_answer_for_turn = additional_answers[key]["input_text"][
# turn_id - 1
# ]
# if additional_answer_for_turn.lower() not in map(str.lower, answers):
# answers.append(additional_answer_for_turn)
# return answers
@staticmethod
def compute_scores(gold_list, pred):
......@@ -98,25 +76,23 @@ class CoQA(Task):
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:]
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1]
return " " + raw_text
return {
"em": em_sum / max(1, len(gold_list)),
"f1": f1_sum / max(1, len(gold_list)),
}
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -126,7 +102,7 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
cont_request = rf.greedy_until(ctx, ['\nQ:'])
cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request
def process_results(self, doc, results):
......@@ -139,15 +115,18 @@ class CoQA(Task):
:param results:
The results of the requests created in construct_requests.
"""
turn_id = len(doc["questions"]["input_text"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0].strip().split('\n')[0]
target = self.doc_to_target(doc).strip()
pred = results[0].strip().split("\n")[0]
# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)
scores = self.compute_scores(gold_list, pred)
# TODO: Add HF metrics mapped from promptsource metadata.
scores = self.compute_scores([target], pred)
return {
"f1": scores['f1'],
"em": scores['em'],
"f1": scores["f1"],
"em": scores["em"],
}
def higher_is_better(self):
......
......@@ -70,21 +70,26 @@ class DROP(Task):
@classmethod
def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers):
""" Flattens a dict of lists of validated answers.
"""Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}]
"""
vas = []
for i in range(len(validated_answers["number"])):
vas.append({
vas.append(
{
"number": validated_answers["number"][i],
"date": validated_answers["date"][i],
"spans": validated_answers["spans"][i],
})
}
)
return vas
answers = []
answers_set = set()
candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
candidates = [qa["answer"]] + _flatten_validated_answers(
qa["validated_answers"]
)
for candidate in candidates:
answer = cls.parse_answer(candidate)
if answer in answers_set:
......@@ -100,15 +105,17 @@ class DROP(Task):
return (str(answer["number"]),)
if answer["spans"] != []:
return tuple(answer["spans"])
return (" ".join([answer["date"]["day"],
answer["date"]["month"],
answer["date"]["year"]]).strip(),)
return (
" ".join(
[answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
def doc_to_text(self, doc):
    """Render the prompt: passage line, question line, then an ``Answer:`` cue."""
    passage = doc["passage"]
    question = doc["question"]
    prompt_lines = (f"Passage: {passage}", f"Question: {question}", "Answer:")
    return "\n".join(prompt_lines)
# def doc_to_text(self, doc):
# return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc):
    """Return the first entry of ``doc["answers"]`` joined with ``", "`` after a leading space."""
    first_answer = doc["answers"][0]
    joined = ", ".join(first_answer)
    return f" {joined}"
# def doc_to_target(self, doc):
# return " " + ", ".join(doc["answers"][0])
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
......@@ -134,7 +141,13 @@ class DROP(Task):
:param results:
The results of the requests created in construct_requests.
"""
preds, golds = results, doc["answers"]
pred = results[0].strip()
target = self.doc_to_target(doc).strip()
preds = [pred]
golds = [target]
max_em = 0
max_f1 = 0
for gold_answer in golds:
......@@ -142,10 +155,7 @@ class DROP(Task):
if gold_answer[0].strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
return {
"em": max_em,
"f1": max_f1
}
return {"em": max_em, "f1": max_f1}
def get_metrics(self, predicted, gold):
"""
......@@ -158,7 +168,9 @@ class DROP(Task):
predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
if set(predicted_bags[0]) == set(gold_bags[0]) and len(
predicted_bags[0]
) == len(gold_bags[0]):
exact_match = 1.0
else:
exact_match = 0.0
......@@ -190,7 +202,9 @@ class DROP(Task):
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
scores[gold_index, pred_index] = self._compute_f1(
pred_item, gold_item
)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))])
......@@ -256,7 +270,11 @@ class DROP(Task):
def _normalize(self, answer):
tokens = [
self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
self._white_space_fix(
self._remove_articles(
self._fix_number(self._remove_punc(token.lower()))
)
)
for token in self._tokenize(answer)
]
tokens = [token for token in tokens if token.strip()]
......@@ -269,10 +287,7 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"em": mean,
"f1": mean
}
return {"em": mean, "f1": mean}
def higher_is_better(self):
"""
......@@ -280,7 +295,4 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"em": True,
"f1": True
}
return {"em": True, "f1": True}
......@@ -40,7 +40,7 @@ class RACE(Task):
# Dataset configuration name; passed as `name=` to `datasets.load_dataset`
# when the documents of a split are collated.
DATASET_NAME = "high"
# Collated documents keyed by split (written as `self.cache[set] = res`).
# Defined on the class, so the same dict is reached through every instance.
cache = {}
# Maps an answer letter to the index of the matching entry in a problem's
# "options" list. The mapping appears twice (single- and double-quoted
# spellings of the same dict); the second assignment is the effective one.
letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
def has_training_docs(self):
    """Report that this task has training documents (always ``True``)."""
    return True
......@@ -59,17 +59,27 @@ class RACE(Task):
# is shown that one document is made per passage.
r = collections.defaultdict(list)
for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]:
r[item['article']].append(item)
res = list(r.values() >> each(lambda x: {
'article': x[0]['article'],
'problems': x >> each(lambda y: {
'question': y['question'],
'answer': y['answer'],
'options': y['options'],
})
}))
for item in datasets.load_dataset(
path=self.DATASET_PATH, name=self.DATASET_NAME
)[set]:
r[item["article"]].append(item)
res = list(
r.values()
>> each(
lambda x: {
"article": x[0]["article"],
"problems": x
>> each(
lambda y: {
"question": y["question"],
"answer": y["answer"],
"options": y["options"],
}
),
}
)
)
self.cache[set] = res
return res
......@@ -85,45 +95,44 @@ class RACE(Task):
@classmethod
def get_answer_option(cls, problem):
    """Return the option text selected by the problem's answer letter.

    The letter in ``problem["answer"]`` is translated to a list index through
    ``cls.letter_to_num`` and used to pick from ``problem["options"]``.
    """
    option_index = cls.letter_to_num[problem["answer"]]
    choices = problem["options"]
    return choices[option_index]
@classmethod
def last_problem(cls, doc):
    """Return the final entry of ``doc["problems"]``."""
    problems = doc["problems"]
    return problems[-1]
def doc_to_text(self, doc):
    """Build the prompt for a passage: article, solved problems, final question.

    Every problem except the last is shown together with its correct option;
    the last problem contributes only its question, which the model is asked
    to answer.

    Cloze-style questions (those ending in a blank written as ``_`` followed
    by a period) are rendered as the question stem completed with the correct
    option. The previous implementation compared a six-character slice
    against a shorter literal, so the cloze branch could practically never
    run, and when it did it emitted the trailing blank marker instead of the
    stem. The blank is now detected with a whitespace-tolerant pattern and
    removed before the answer is appended.

    :param doc:
        Mapping with ``"article"`` (str) and ``"problems"`` (list of dicts
        holding ``"question"``, ``"answer"`` and ``"options"``).
    :return:
        The prompt string.
    """
    # Function-scope import: the file's import block is not visible here.
    import re

    # Trailing blank placeholder, e.g. " _ ." or "  _  .".
    cloze_blank = re.compile(r"\s+_\s+\.$")

    pieces = ["Article: " + doc["article"] + "\n\n"]
    for problem in doc["problems"][:-1]:
        question = problem["question"]
        answer = self.get_answer_option(problem)
        if cloze_blank.search(question):
            # Fill the blank: keep the stem, drop the placeholder.
            stem = cloze_blank.sub("", question)
            pieces.append(stem + " " + answer + "\n")
        else:
            pieces.append("Question: " + question + "\n")
            pieces.append("Answer: " + answer + "\n")
    pieces.append(self.last_problem(doc)["question"])
    # Single join instead of repeated string concatenation.
    return "".join(pieces)
def doc_to_target(self, doc):
    """Return the correct option of the document's last problem, with a leading space."""
    final_problem = self.last_problem(doc)
    correct_option = self.get_answer_option(final_problem)
    return " " + correct_option
def construct_requests(self, doc, ctx):
    """Uses RequestFactory to construct Requests and returns an iterable of
    Requests which will be sent to the LM.

    One loglikelihood request is created for each of the first four options
    of the document's last problem; each continuation is the option text
    preceded by a space, and element ``[0]`` of the factory's result is kept.

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the natural
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
    """
    final_problem = self.last_problem(doc)
    option_requests = []
    for position in range(4):
        continuation = " " + final_problem["options"][position]
        option_requests.append(rf.loglikelihood(ctx, continuation)[0])
    return option_requests
return doc["problems"][-1]
# def doc_to_text(self, doc):
# text = 'Article: ' + doc['article'] + '\n\n'
# for problem in doc['problems'][:-1]:
# if problem['question'][-6:] == ' _ .':
# text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
# else:
# question = 'Question: ' + problem['question'] + '\n'
# answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
# text += question + answer
# text += self.last_problem(doc)['question']
# return text
# def doc_to_target(self, doc):
# return " " + self.get_answer_option(self.last_problem(doc))
# def construct_requests(self, doc, ctx):
# """Uses RequestFactory to construct Requests and returns an iterable of
# Requests which will be sent to the LM.
# :param doc:
# The document as returned from training_docs, validation_docs, or test_docs.
# :param ctx: str
# The context string, generated by fewshot_context. This includes the natural
# language description, as well as the few shot examples, and the question
# part of the document for `doc`.
# """
# problem = self.last_problem(doc)
# ll_choices = [
# rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
# ]
# return ll_choices
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -135,11 +144,11 @@ class RACE(Task):
:param results:
The results of the requests created in construct_requests.
"""
gold = self.letter_to_num[self.last_problem(doc)['answer']]
#
gold = self.letter_to_num[self.doc_to_target(doc)]
# gold = self.letter_to_num[self.last_problem(doc)["answer"]]
pred = np.argmax(results)
return {
"acc": int(pred == gold)
}
return {"acc": int(pred == gold)}
def aggregation(self):
"""
......@@ -147,9 +156,7 @@ class RACE(Task):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
return {"acc": mean}
def higher_is_better(self):
"""
......@@ -157,6 +164,4 @@ class RACE(Task):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
return {"acc": True}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment