Commit 4d147bdd authored by Jonathan Tow

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into task-guide

parents 011cc891 dc937d4b
[run]
# tasks that aren't wired up.
omit =
lm_eval/tasks/quac.py
lm_eval/tasks/storycloze.py
lm_eval/tasks/cbt.py
lm_eval/tasks/sat.py
lm_eval/tasks/triviaqa.py
lm_eval/tasks/naturalqs.py
lm_eval/models/dummy.py
[report]
exclude_lines =
# Skip any pass lines such as may be used for @abstractmethod
pass
# Have to re-enable the standard pragma
pragma: no cover
# Don't complain about missing debug-only code:
def __repr__
if self\.debug
# Don't complain if tests don't hit defensive assertion code:
raise AssertionError
raise NotImplementedError
return NotImplemented
\ No newline at end of file
# This workflow will install Python dependencies, run tests and lint with a single version of Python # This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Python application name: Build
on: on:
push: push:
...@@ -21,10 +21,9 @@ jobs: ...@@ -21,10 +21,9 @@ jobs:
with: with:
# A list of files, directories, and wildcard patterns to cache and restore # A list of files, directories, and wildcard patterns to cache and restore
path: | path: |
data
~/.cache ~/.cache
# An explicit key for restoring and saving the cache # An explicit key for restoring and saving the cache
key: evaldata-cache key: evaldata-cache-3
- name: Set up Python 3.9 - name: Set up Python 3.9
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
...@@ -46,4 +45,4 @@ jobs: ...@@ -46,4 +45,4 @@ jobs:
pytest --cov=lm_eval/ tests/ pytest --cov=lm_eval/ tests/
- name: Upload to codecov - name: Upload to codecov
run: | run: |
bash <(curl -s https://codecov.io/bash) bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN
\ No newline at end of file
...@@ -2,3 +2,4 @@ env ...@@ -2,3 +2,4 @@ env
*.pyc *.pyc
data/ data/
lm_cache lm_cache
.idea
\ No newline at end of file
@software{eval-harness,
author = {Gao, Leo and
Tow, Jonathan and
Biderman, Stella and
Black, Sid and
DiPofi, Anthony and
Foster, Charles and
Golding, Laurence and
Hsu, Jeffrey and
McDonell, Kyle and
Muennighoff, Niklas and
Phang, Jason and
Reynolds, Laria and
Tang, Eric and
Thite, Anish and
Wang, Ben and
Wang, Kevin and
Zou, Andy},
title = {A framework for few-shot language model evaluation},
month = sep,
year = 2021,
publisher = {Zenodo},
version = {v0.0.1},
doi = {10.5281/zenodo.5371628},
url = {https://doi.org/10.5281/zenodo.5371628}
}
# Tasks that load data via the HuggingFace `nlp` library generally do not require separately downloading data
#coqa
mkdir -p data/coqa
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
#drop
mkdir -p data/drop
wget https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip -O data/drop.zip
unzip data/drop.zip -d data/drop
rm data/drop.zip
mv data/drop/drop_dataset/* data/drop
rm -rf data/drop/drop_dataset
import abc import abc
import random import random
import numpy as np import numpy as np
import re
from lm_eval.metrics import mean from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
class LM(abc.ABC): class LM(abc.ABC):
def __init__(self):
self.cache_hook = CacheHook(None)
@abc.abstractmethod @abc.abstractmethod
def loglikelihood(self, requests): def loglikelihood(self, requests):
"""Compute log-likelihood of generating a continuation from a context. """Compute log-likelihood of generating a continuation from a context.
...@@ -24,9 +28,51 @@ class LM(abc.ABC): ...@@ -24,9 +28,51 @@ class LM(abc.ABC):
:return: list :return: list
A list of pairs (logprob, isgreedy) A list of pairs (logprob, isgreedy)
logprob: float logprob: float
The log probability of `contination` The log probability of `continuation`
isgreedy:
Whether `continuation` would be generated by greedy sampling from `context`
"""
pass
@abc.abstractmethod
def loglikelihood_rolling(self, requests):
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT: EOT 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list
A list of strings
string: str
String for which we are computing per-token loglikelihood
:return: list
A list of pairs (logprob, isgreedy)
logprob: float
The log probability of `continuation`
isgreedy: isgreedy:
Whether `contination` would be generated by greedy sampling from `context` Whether `continuation` would be generated by greedy sampling from `context`
""" """
pass pass
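For reference, here is a minimal sketch of the windowing scheme described in the docstring above, assuming the context_len=1 behaviour; the helper name rolling_token_windows and the driver code are illustrative, not the harness's own utils.get_rolling_token_windows.
def rolling_token_windows(tokens, prefix_token, max_seq_len):
    # First window: prefix with EOT and predict the first max_seq_len tokens.
    first = tokens[:max_seq_len]
    yield ([prefix_token] + first)[:len(first)], first
    pos = max_seq_len
    while pos < len(tokens):
        end = min(pos + max_seq_len, len(tokens))
        targets = tokens[pos:end]
        # Back-fill the input so the final window still gets a full-sized context;
        # only the positions aligned with `targets` would be scored.
        inputs = tokens[max(0, end - 1 - max_seq_len):end - 1]
        yield inputs, targets
        pos = end
for inp, pred in rolling_token_windows(list(range(10)), prefix_token=-1, max_seq_len=4):
    print(inp, pred)
# [-1, 0, 1, 2] [0, 1, 2, 3]
# [3, 4, 5, 6] [4, 5, 6, 7]
# [5, 6, 7, 8] [8, 9]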
...@@ -60,6 +106,9 @@ class LM(abc.ABC): ...@@ -60,6 +106,9 @@ class LM(abc.ABC):
""" """
return cls() return cls()
def set_cache_hook(self, cache_hook):
self.cache_hook = cache_hook
class Task(abc.ABC): class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems, """A task represents an entire benchmark including its dataset, problems,
...@@ -189,7 +238,7 @@ class Task(abc.ABC): ...@@ -189,7 +238,7 @@ class Task(abc.ABC):
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else: else:
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs else self.test_docs()) self._fewshot_docs = list(self.validation_docs() if self.has_validation_docs() else self.test_docs())
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
...@@ -220,25 +269,91 @@ class MultipleChoiceTask(Task): ...@@ -220,25 +269,91 @@ class MultipleChoiceTask(Task):
gold = doc["gold"] gold = doc["gold"]
acc = 1. if np.argmax(results) == gold else 0. acc = 1. if np.argmax(results) == gold else 0.
completion_len = np.array([float(len(i)) for i in doc["choices"]])
acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
return { return {
"acc": acc "acc": acc,
"acc_norm": acc_norm,
} }
def higher_is_better(self): def higher_is_better(self):
return { return {
"acc": True "acc": True,
"acc_norm": True,
} }
def aggregation(self): def aggregation(self):
return { return {
"acc": mean "acc": mean,
"acc_norm": mean,
} }
class PerplexityTask(Task, abc.ABC):
def has_training_docs(self):
return False
def fewshot_description(self):
return ""
def fewshot_examples(self, k, rnd):
assert k == 0
return []
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
assert num_fewshot == 0
assert not provide_description
return ""
def higher_is_better(self):
return {
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def doc_to_text(self, doc):
return ""
def doc_to_target(self, doc):
return doc
def construct_requests(self, doc, ctx):
assert not ctx
req = rf.loglikelihood_rolling(self.doc_to_target(doc))
return req
def process_results(self, doc, results):
loglikelihood, = results
words = self.count_words(doc)
bytes = self.count_bytes(doc)
return {
"word_perplexity": (loglikelihood, words),
"byte_perplexity": (loglikelihood, bytes),
"bits_per_byte": (-loglikelihood, self.count_bytes(doc))
}
def aggregation(self):
return {
"word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity,
"bits_per_byte": weighted_mean
}
def count_bytes(self, doc):
return len(doc.encode("utf-8"))
def count_words(self, doc):
""" Downstream tasks with custom word boundaries should override this! """
return len(re.split(r"\s+", doc))
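As a hedged illustration of how the (loglikelihood, weight) pairs emitted above aggregate at the corpus level, using the weighted_mean/weighted_perplexity definitions added to lm_eval/metrics.py further down (the numbers are made up):
import math
def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)
# two documents: (total log-likelihood in nats, word count)
items = [(-120.0, 50), (-30.0, 10)]
print(math.exp(-weighted_mean(items)))  # word_perplexity = exp(150 / 60) ~= 12.18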
req_ret_lens = { req_ret_lens = {
'loglikelihood': 2, 'loglikelihood': 2,
'greedy_until': None, 'greedy_until': None,
'loglikelihood_rolling': None,
} }
import os import os
...@@ -251,13 +366,31 @@ def hash_args(attr, args): ...@@ -251,13 +366,31 @@ def hash_args(attr, args):
return hashlib.sha256(dat.encode('utf-8')).hexdigest() return hashlib.sha256(dat.encode('utf-8')).hexdigest()
class CacheHook:
def __init__(self, cachinglm):
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res):
if self.dbdict is None:
return
hsh = hash_args(attr, req)
self.dbdict[hsh] = res
class CachingLM: class CachingLM:
def __init__(self, lm, cache_db): def __init__(self, lm, cache_db):
self.lm = lm self.lm = lm
self.cache_db = cache_db self.cache_db = cache_db
os.makedirs(os.path.dirname(cache_db), exist_ok=True) if os.path.dirname(cache_db): os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True) self.dbdict = SqliteDict(cache_db, autocommit=True)
# add hook to lm
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr): def __getattr__(self, attr):
def fn(requests): def fn(requests):
res = [] res = []
...@@ -293,6 +426,9 @@ class CachingLM: ...@@ -293,6 +426,9 @@ class CachingLM:
return res return res
return fn return fn
def get_cache_hook(self):
return CacheHook(self)
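A hedged usage sketch of the caching path: wrapping an LM in CachingLM answers repeated requests from the on-disk SqliteDict and, via the hook registered above, also writes back partial results produced inside the model's own batching loop. The model class, device, and cache path below are illustrative.
from lm_eval.models.gpt2 import GPT2LM
lm = CachingLM(GPT2LM(device="cuda"), "lm_cache/gpt2.db")
# the first call hits the model; an identical second call is served from the cache
print(lm.loglikelihood([("The capital of France is", " Paris")]))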
class Request: class Request:
......
import collections import collections
import itertools import itertools
import random import random
import lm_eval.metrics
def evaluate(lm, task_dict, provide_description, num_fewshot, limit): def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())] task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
results = collections.defaultdict(dict) results = collections.defaultdict(dict)
versions = collections.defaultdict(dict)
requests = collections.defaultdict(list) requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list) requests_origin = collections.defaultdict(list)
...@@ -23,6 +25,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit): ...@@ -23,6 +25,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict_items: for task_name, task in task_dict_items:
versions[task_name] = task.VERSION
#default to test doc, fall back to val doc if validation unavailable #default to test doc, fall back to val doc if validation unavailable
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
if task.has_test_docs(): if task.has_test_docs():
...@@ -48,7 +51,6 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit): ...@@ -48,7 +51,6 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
reqs = task.construct_requests(doc, ctx) reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)): reqs = [reqs] if not isinstance(reqs, (list, tuple)): reqs = [reqs]
for i, req in enumerate(reqs): for i, req in enumerate(reqs):
requests[req.type].append(req) requests[req.type].append(req)
# i: index in requests for a single task instance # i: index in requests for a single task instance
...@@ -64,6 +66,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit): ...@@ -64,6 +66,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
# only in index. We could implement some kind of caching, but that would be more of a bandaid # only in index. We could implement some kind of caching, but that would be more of a bandaid
# solution. we could also implement some kind of autogrouping here; they should end up next to each other. # solution. we could also implement some kind of autogrouping here; they should end up next to each other.
print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs]) resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)] resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
...@@ -89,5 +92,14 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit): ...@@ -89,5 +92,14 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
for (task_name, metric), items in vals.items(): for (task_name, metric), items in vals.items():
task = task_dict[task_name] task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items) results[task_name][metric] = task.aggregation()[metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them for fewer iterations. Still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric], bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters)
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
return results return {
"results": results,
"versions": versions
}
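For orientation, a hedged sketch of the shape now returned by evaluate(); the task and metric names are illustrative.
example_output = {
    "results": {"lambada": {"ppl": 3.97, "ppl_stderr": 0.03, "acc": 0.70, "acc_stderr": 0.004}},
    "versions": {"lambada": 0},
}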
import math import math
from collections import Iterable from collections import Iterable
from pprint import pprint
import numpy as np import numpy as np
import sacrebleu import sacrebleu
import sklearn import sklearn
import random
def mean(arr): def mean(arr):
return sum(arr) / len(arr) return sum(arr) / len(arr)
def pop_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
def sample_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
def median(arr): def median(arr):
return arr[len(arr) // 2] return arr[len(arr) // 2]
...@@ -47,6 +63,23 @@ def acc_all(items): ...@@ -47,6 +63,23 @@ def acc_all(items):
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()]) acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
return acc return acc
def acc_all_stderr(items):
# Only count as correct if all answers are labeled correctly for each question
question_scoring_dict = {}
preds = list(zip(*items))[0]
docs = list(zip(*items))[1]
for doc, pred in zip(docs, preds):
question_id = doc["idx"]["question"]
if question_id not in question_scoring_dict:
question_scoring_dict[question_id] = []
gold_label = doc["label"] == 1
question_scoring_dict[question_id].append(gold_label == pred)
acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
return acc
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Compute max metric between prediction and each ground truth.""" """Compute max metric between prediction and each ground truth."""
...@@ -61,6 +94,14 @@ def perplexity(items): ...@@ -61,6 +94,14 @@ def perplexity(items):
return math.exp(-mean(items)) return math.exp(-mean(items))
def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
def weighted_perplexity(items):
return math.exp(-weighted_mean(items))
def bleu(items): def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching for evaluating a generated sentence to a reference sentence. It counts matching
...@@ -124,7 +165,7 @@ def _sacreformat(refs, preds): ...@@ -124,7 +165,7 @@ def _sacreformat(refs, preds):
# Must become List[List[str]] with the inner list corresponding to preds # Must become List[List[str]] with the inner list corresponding to preds
if not is_non_str_iterable(refs): if not is_non_str_iterable(refs):
refs = list(refs) refs = list(refs)
if not is_non_str_iterable(refs): if not is_non_str_iterable(refs[0]):
refs = [[ref] for ref in refs] refs = [[ref] for ref in refs]
refs = list(zip(*refs)) refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds # Note the number of refs in each ref list much match the number of preds
...@@ -137,3 +178,62 @@ def _sacreformat(refs, preds): ...@@ -137,3 +178,62 @@ def _sacreformat(refs, preds):
preds = [pred[0] for pred in preds] preds = [pred[0] for pred in preds]
return refs, preds return refs, preds
## stderr stuff
class _bootstrap_internal:
def __init__(self, f, n):
self.f = f
self.n = n
def __call__(self, v):
i, xs = v
rnd = random.Random()
rnd.seed(i)
res = []
for _ in range(self.n):
res.append(self.f(rnd.choices(xs, k=len(xs))))
return res
def bootstrap_stderr(f, xs, iters):
import multiprocessing as mp
pool = mp.Pool(mp.cpu_count())
# this gives a biased estimate of the stderr (i.e. with the mean, it gives something
# equivalent to the stderr calculated without Bessel's correction in the stddev).
# Unfortunately, I haven't been able to figure out what the right correction is
# to make the bootstrap unbiased - I considered multiplying by sqrt(n/(n-1)), but
# that would be ad hoc and I can't prove it would actually be an unbiased estimator.
# Thankfully, this shouldn't matter much because our samples are usually pretty big anyway.
res = []
chunk_size = min(1000, iters)
from tqdm import tqdm
print("bootstrapping for stddev:", f.__name__)
for bootstrap in tqdm(pool.imap(_bootstrap_internal(f, chunk_size), [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
# sample w replacement
res.extend(bootstrap)
pool.close()
return sample_stddev(res)
def stderr_for_metric(metric, bootstrap_iters):
bootstrappable = [
median,
matthews_corrcoef,
f1_score,
perplexity,
bleu,
chrf,
ter,
]
if metric in bootstrappable:
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
stderr = {
mean: mean_stderr,
acc_all: acc_all_stderr
}
return stderr.get(metric, None)
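A brief hedged usage sketch: metrics with a closed-form standard error (mean, acc_all) are dispatched to their analytic estimator, while anything in bootstrappable is resampled. The sample scores are made up.
scores = [1.0, 0.0, 1.0, 1.0, 0.0]
assert stderr_for_metric(mean, bootstrap_iters=100000) is mean_stderr  # analytic path
print(stderr_for_metric(median, bootstrap_iters=1000)(scores))         # bootstrap path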
...@@ -26,3 +26,11 @@ class DummyLM(LM): ...@@ -26,3 +26,11 @@ class DummyLM(LM):
assert ctx.strip() != '' assert ctx.strip() != ''
return res return res
def loglikelihood_rolling(self, requests):
res = []
for _ in requests:
res.append(-random.random())
return res
\ No newline at end of file
import transformers import transformers
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from lm_eval.base import LM from lm_eval.base import LM
from lm_eval import utils from lm_eval import utils
from tqdm import tqdm from tqdm import tqdm
import numpy as np
class GPT2LM(LM): class GPT2LM(LM):
MAX_GEN_TOKS = 256 MAX_GEN_TOKS = 256
VOCAB_SIZE = 50257
EOT_TOKEN_ID = 50256
def __init__(self, device=None, pretrained='gpt2'): def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
super().__init__()
if device: if device:
self.device = torch.device(device) self.device = torch.device(device)
else: else:
...@@ -20,56 +25,171 @@ class GPT2LM(LM): ...@@ -20,56 +25,171 @@ class GPT2LM(LM):
# pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2 # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
self.tokenizer.pad_token = "<|endoftext|>" self.tokenizer.pad_token = "<|endoftext|>"
self.max_length = self.gpt2.config.n_ctx try:
self.max_length = self.gpt2.config.n_ctx
except AttributeError:
# GPTNeoConfig doesn't have n_ctx, apparently
self.max_length = self.gpt2.config.max_position_embeddings
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373] assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
# multithreading and batching
gpus = torch.cuda.device_count()
batch_size_per_gpu = batch_size # todo: adaptive batch size
self.batch_size = batch_size_per_gpu * gpus
# TODO: fix multi-gpu
# if gpus > 1:
# self.gpt2 = nn.DataParallel(self.gpt2)
@classmethod @classmethod
def create_from_arg_string(cls, arg_string): def create_from_arg_string(cls, arg_string, additional_config={}):
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2")) args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
def loglikelihood(self, requests): def loglikelihood(self, requests):
new_reqs = []
for context, continuation in requests:
if context == "":
# end of text as context
context_enc = [self.EOT_TOKEN_ID]
else:
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
with torch.no_grad():
for string, in tqdm(requests):
rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
token_list=self.tokenizer.encode(string),
prefix_token=self.EOT_TOKEN_ID,
max_seq_len=self.max_length,
context_len=1,
)))
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = [] res = []
with torch.no_grad(): with torch.no_grad():
# TODO: vectorize properly
# TODO: automatic batch size detection for vectorization
def _collate(x): def _collate(x):
toks = self.tokenizer.encode(x[0] + x[1])[:-1] # the negative sign on len(toks) sorts descending - this has a few advantages:
return (len(toks), self.tokenizer.decode(toks)) # - time estimates will always be over- rather than underestimates, which is more useful for planning
# - when going through the list, the first element of a batch is always the batch's padded context length.
# this is useful for simplifying the batching logic and, more importantly, for making automatic adaptive batches much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[1] + x[2]
return (-len(toks), tuple(toks))
# TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate) reord = utils.Reorderer(requests, _collate)
for context, continuation in tqdm(reord.get_reordered()): for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
# when too long to fit in context, truncate from the left inps = []
contlens = []
if context == "": inplens = []
# end of text as context
context_enc = [50256]
else:
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
cont_toks = inp[:, ctxlen:] # [batch, seq] padding_length = None
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
greedy_tokens = logits.argmax(dim=-1) # tensors, then we pack them together into a batch, call the model, and then pick it all apart
max_equal = (greedy_tokens == cont_toks).all() # again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
last_token_slice = logits[:, -1, :].squeeze(0).tolist() # how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
# cont_toks 4 5 6 7 8 9
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq] # when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1]
, dtype=torch.long).to(self.device)
inplen, = inp.shape
res.append((float(logits[:, :-1].sum() if logits.shape[-1] > 1 else 0), last_token_slice, bool(max_equal))) cont = continuation_enc
# optimization: if two requests have everything the same except the last token, use # since in _collate we make sure length is descending, the longest is always the first one.
# last token distribution to save compute padding_length = padding_length if padding_length is not None else inplen
lasttoks = [self.tokenizer.encode(x[1])[-1] for x in requests]
return [(l + lts[lasttok], m) for (l, lts, m), lasttok in zip(reord.get_original(res), lasttoks)] # pad to length
inp = torch.cat([
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
], dim=0)
inps.append(inp.unsqueeze(0))
contlens.append(cont)
inplens.append(inplen)
multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu() # [batch, seq, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
contlen = len(cont_toks)
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
# cont_toks :: [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
max_equal = (greedy_tokens == cont_toks).all()
#last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
answer = (float(logits.sum()), bool(max_equal))
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
res.append(answer)
return reord.get_original(res)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
return self.gpt2(inps)[0][:, :, :50257]
def greedy_until(self, requests): def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are # TODO: implement fully general `until` that handles untils that are
...@@ -101,6 +221,9 @@ class GPT2LM(LM): ...@@ -101,6 +221,9 @@ class GPT2LM(LM):
for term in until: for term in until:
s = s.split(term)[0] s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s) res.append(s)
return reord.get_original(res) return reord.get_original(res)
import os import os
import numpy as np
import transformers import transformers
from lm_eval.base import LM from lm_eval.base import LM
from lm_eval import utils from lm_eval import utils
...@@ -48,6 +49,7 @@ class GPT3LM(LM): ...@@ -48,6 +49,7 @@ class GPT3LM(LM):
:param truncate: bool :param truncate: bool
Truncate input if too long (if False and input is too long, throw error) Truncate input if too long (if False and input is too long, throw error)
""" """
super().__init__()
import openai import openai
self.engine = engine self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
...@@ -57,16 +59,57 @@ class GPT3LM(LM): ...@@ -57,16 +59,57 @@ class GPT3LM(LM):
self.tokenizer.pad_token = "<|endoftext|>" self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373] assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
self.truncate = truncate self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]
# Read from environment variable OPENAI_API_SECRET_KEY # Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"] openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@classmethod @classmethod
def create_from_arg_string(cls, arg_string): def create_from_arg_string(cls, arg_string, additional_config={}):
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
return cls(engine=args.get("engine", "davinci")) args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2)
def loglikelihood(self, requests): def loglikelihood(self, requests):
new_reqs = []
for context, continuation in requests:
if context == "":
# end of text as context
context_enc = [50256]
else:
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
# TODO: switch implementation to use _loglikelihood_tokens rather than having it do its own thing
loglikelihoods = []
for string, in tqdm(requests):
encoded = self.tokenizer.encode_plus(string)["input_ids"]
rolling_token_windows = utils.get_rolling_token_windows(
token_list=encoded,
prefix_token=self.end_of_text_token_id,
max_seq_len=self.MAX_LENGTH,
context_len=1,
)
string_loglikelihoods = []
for input_tokens, pred_tokens in rolling_token_windows:
block_output = self.get_token_logprobs(
input_tokens=input_tokens,
pred_tokens=pred_tokens,
)
string_loglikelihoods.append(block_output["logprobs"])
string_loglikelihoods = np.concatenate(string_loglikelihoods).sum()
loglikelihoods.append(string_loglikelihoods)
return loglikelihoods
def _loglikelihood_tokens(self, requests):
import openai import openai
res = [] res = []
...@@ -74,22 +117,15 @@ class GPT3LM(LM): ...@@ -74,22 +117,15 @@ class GPT3LM(LM):
# this doesn't efficiently handle last-token differences yet, but those are kinda annoying because # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
# it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
# we care about and so we need some kind of backup for when it isn't # we care about and so we need some kind of backup for when it isn't
toks = self.tokenizer.encode(x[0] + x[1]) toks = x[1] + x[2]
return (len(toks), self.tokenizer.decode(toks)) return (-len(toks), tuple(toks))
reord = utils.Reorderer(requests, _collate) reord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))): for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = [] inps = []
ctxlens = [] ctxlens = []
for context, continuation in chunk: for cache_key, context_enc, continuation_enc in chunk:
if context == "":
# end of text as context
context_enc = [50256]
else:
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:] inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH) ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
...@@ -104,11 +140,38 @@ class GPT3LM(LM): ...@@ -104,11 +140,38 @@ class GPT3LM(LM):
logprobs=10, logprobs=10,
) )
for resp, ctxlen in zip(response.choices, ctxlens): for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
res.append(get_result(resp, ctxlen)) answer = get_result(resp, ctxlen)
res.append(answer)
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return reord.get_original(res) return reord.get_original(res)
def get_token_logprobs(self, input_tokens, pred_tokens):
pred_start = len(input_tokens) - len(pred_tokens) + 1
# We're going to stitch together the input_tokens and pred_tokens
# In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
assert input_tokens[pred_start:] == pred_tokens[:-1]
token_ids = input_tokens + [pred_tokens[-1]]
response = oa_completion(
engine=self.engine,
prompt=token_ids,
max_tokens=0,
temperature=0.0,
logprobs=0,
echo=True,
)
logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
positions = np.arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))
return {
"logprobs": logprobs,
"positions": positions,
}
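A worked illustration of the stitching and indexing above; the token ids are made up and correspond to the first, EOT-prefixed window, where pred_tokens overlaps input_tokens everywhere except its final token.
input_tokens = [50256, 11, 22, 33]  # EOT-prefixed window
pred_tokens = [11, 22, 33, 44]      # tokens to be scored
pred_start = len(input_tokens) - len(pred_tokens) + 1  # = 1
assert input_tokens[pred_start:] == pred_tokens[:-1]   # [11, 22, 33]
token_ids = input_tokens + [pred_tokens[-1]]           # [50256, 11, 22, 33, 44]
# the API echoes one logprob per position; slicing with [pred_start:] keeps exactly
# the positions whose targets are pred_tokens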
def greedy_until(self, requests): def greedy_until(self, requests):
if not requests: return [] if not requests: return []
import openai import openai
...@@ -149,13 +212,15 @@ class GPT3LM(LM): ...@@ -149,13 +212,15 @@ class GPT3LM(LM):
stop=until stop=until
) )
for resp in response.choices: for resp, (context, until) in zip(response.choices, chunk):
s = resp['text'] s = resp['text']
for term in until: for term in until:
s = s.split(term)[0] s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s) res.append(s)
return reord.get_original(res) return reord.get_original(res)
...@@ -21,6 +21,8 @@ from . import arithmetic ...@@ -21,6 +21,8 @@ from . import arithmetic
from . import lambada from . import lambada
from . import race from . import race
from . import piqa from . import piqa
from . import prost
from . import mc_taco
from . import triviaqa from . import triviaqa
from . import pubmedqa from . import pubmedqa
from . import sciq from . import sciq
...@@ -29,12 +31,18 @@ from . import qa4mre ...@@ -29,12 +31,18 @@ from . import qa4mre
from . import translation from . import translation
from . import headqa from . import headqa
from . import mathqa from . import mathqa
from . import ethics from . import hendrycks_ethics
from . import drop from . import drop
from . import unscramble from . import unscramble
from . import logiqa from . import logiqa
from . import hendrycks_test from . import hendrycks_test
from . import math from . import hendrycks_math
from . import cbt
from . import lambada_cloze
from . import pile
from . import wikitext
from . import lambada_multilingual
from . import mutual
######################################## ########################################
# Translation tasks # Translation tasks
...@@ -91,25 +99,36 @@ TASK_REGISTRY = { ...@@ -91,25 +99,36 @@ TASK_REGISTRY = {
"coqa": coqa.CoQA, "coqa": coqa.CoQA,
"drop": drop.DROP, "drop": drop.DROP,
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada
**lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"prost": prost.PROST,
"mc_taco": mc_taco.MCTACO,
# Science related # Science related
"pubmedqa" : pubmedqa.Pubmed_QA, "pubmedqa" : pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ, "sciq" : sciq.SciQ,
#"qa4mre" : qa4mre.QA4MRE,
"qa4mre_2011" : qa4mre.QA4MRE_2011, "qa4mre_2011" : qa4mre.QA4MRE_2011,
"qa4mre_2012" : qa4mre.QA4MRE_2012, "qa4mre_2012" : qa4mre.QA4MRE_2012,
"qa4mre_2013" : qa4mre.QA4MRE_2013, "qa4mre_2013" : qa4mre.QA4MRE_2013,
#"triviaqa": triviaqa.TriviaQA, "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
# "quac": quac.QuAC, # not implemented yet # "quac": quac.QuAC, # not implemented yet
"logiqa": logiqa.LogiQA, "logiqa": logiqa.LogiQA,
"hellaswag": hellaswag.HellaSwag, # not implemented yet "hellaswag": hellaswag.HellaSwag,
"openbookqa": openbookqa.OpenBookQA, "openbookqa": openbookqa.OpenBookQA,
# "sat": sat.SATAnalogies, # not implemented yet # "sat": sat.SATAnalogies, # not implemented yet
# "squad": squad.SQuAD, # not implemented yet "squad2": squad.SQuAD2,
"race": race.RACE, "race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet # "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQA, "headqa": headqa.HeadQA,
...@@ -121,21 +140,25 @@ TASK_REGISTRY = { ...@@ -121,21 +140,25 @@ TASK_REGISTRY = {
"anli_r2": anli.ANLIRound2, "anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3, "anli_r3": anli.ANLIRound3,
"ethics_cm": ethics.EthicsCM, "ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": ethics.EthicsDeontology, "ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": ethics.EthicsJustice, "ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
# dialogue
"mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus,
# math # math
"math_algebra": math.MathAlgebra, "math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": math.MathCountingAndProbability, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
"math_geometry": math.MathGeometry, "math_geometry": hendrycks_math.MathGeometry,
"math_intermediate_algebra": math.MathIntermediateAlgebra, "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
"math_num_theory": math.MathNumberTheory, "math_num_theory": hendrycks_math.MathNumberTheory,
"math_prealgebra": math.MathPrealgebra, "math_prealgebra": hendrycks_math.MathPrealgebra,
"math_precalc": math.MathPrecalculus, "math_precalc": hendrycks_math.MathPrecalculus,
# arithmetic # arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2da": arithmetic.Arithmetic2DPlus,
...@@ -165,6 +188,31 @@ TASK_REGISTRY = { ...@@ -165,6 +188,31 @@ TASK_REGISTRY = {
"cycle_letters": unscramble.CycleLetters, "cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion, "random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords, "reversed_words": unscramble.ReversedWords,
# Pile
"pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3,
"pile_bookcorpus2": pile.PileBookCorpus2,
"pile_dm-mathematics": pile.PileDmMathematics,
"pile_enron": pile.PileEnron,
"pile_europarl": pile.PileEuroparl,
"pile_freelaw": pile.PileFreeLaw,
"pile_github": pile.PileGithub,
"pile_gutenberg": pile.PileGutenberg,
"pile_hackernews": pile.PileHackernews,
"pile_nih-exporter": pile.PileNIHExporter,
"pile_opensubtitles": pile.PileOpenSubtitles,
"pile_openwebtext2": pile.PileOpenWebText2,
"pile_philpapers": pile.PilePhilPapers,
"pile_pile-cc": pile.PilePileCc,
"pile_pubmed-abstracts": pile.PilePubmedAbstracts,
"pile_pubmed-central": pile.PilePubmedCentral,
"pile_stackexchange": pile.PileStackExchange,
"pile_uspto": pile.PileUspto,
"pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
} }
......
...@@ -5,6 +5,7 @@ from . common import HFTask ...@@ -5,6 +5,7 @@ from . common import HFTask
class ANLIBase(HFTask): class ANLIBase(HFTask):
VERSION = 0
DATASET_PATH = "anli" DATASET_PATH = "anli"
DATASET_NAME = None DATASET_NAME = None
SPLIT = None SPLIT = None
......
...@@ -3,6 +3,7 @@ from . common import HFTask ...@@ -3,6 +3,7 @@ from . common import HFTask
class ARCEasy(HFTask, MultipleChoiceTask): class ARCEasy(HFTask, MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "ai2_arc" DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Easy" DATASET_NAME = "ARC-Easy"
...@@ -28,22 +29,6 @@ class ARCEasy(HFTask, MultipleChoiceTask): ...@@ -28,22 +29,6 @@ class ARCEasy(HFTask, MultipleChoiceTask):
} }
return out_doc return out_doc
def _load_docs(self, docs):
for record in docs:
yield self._convert_standard(record)
def training_docs(self):
docs = super().training_docs()
return self._load_docs(docs)
def validation_docs(self):
docs = super().validation_docs()
return self._load_docs(docs)
def test_docs(self):
docs = super().test_docs()
return self._load_docs(docs)
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out description # TODO: figure out description
return "" return ""
......
...@@ -10,6 +10,7 @@ ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion']) ...@@ -10,6 +10,7 @@ ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Task): class Arithmetic(Task):
VERSION = 0
directory = 'data/arithmetic/' directory = 'data/arithmetic/'
def __init__(self): def __init__(self):
......
import numpy as np
from lm_eval.base import rf
from lm_eval.metrics import mean
from .common import HFTask
class CBTBase(HFTask):
"""The Children’s Book Test (CBT) from the paper:
https://research.fb.com/wp-content/uploads/2016/11/the_goldilocks_principle_reading_children_s_books_with_explicit_memory_representations.pdf
NOTE: This evaluation is based on the (context + query) question-answering variant
used by the Recurrent Language Models described in the aforementioned paper.
See section 4.4.
"""
DATASET_PATH = "cbt"
DATASET_NAME = None
VERSION = 0
def fewshot_description(self):
# TODO: Figure out description.
return ""
def detokenize(self, text):
text = text.replace(" '", "'")
text = text.replace(" \n", "\n")
text = text.replace("\n ", "\n")
text = text.replace(" n't", "n't")
text = text.replace("`` ", '"')
text = text.replace("''", '"')
# punctuation
text = text.replace(" :", ":")
text = text.replace(" ;", ";")
text = text.replace(" !", "!")
text = text.replace(" ?", "?")
text = text.replace(" ,", ",")
text = text.replace(" .", ".")
return text
def doc_to_text(self, doc):
passage = " ".join(doc["sentences"])
text = "Passage: " + passage + "\nQuestion: " + doc["question"]
return self.detokenize(text)
def doc_to_target(self, doc):
return ""
def fewshot_examples(self, k, rnd):
assert k == 0, f"CBT is only implemented for the zero-shot setting. Given k={k}."
return super().fewshot_examples(k, rnd)
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
lls = []
for option in doc["options"]:
# Following Section 4.4 "Recurrent Language Models" in the CBT paper:
# "we rank candidate [option] c based on p(q1 . . . qk−1, c, qk+1 . . . ql)
# rather than simply p(q1 . . . qk−1, c)."
lls.append(rf.loglikelihood("", ctx.replace("XXXXX", option))[0])
return lls
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
gold = doc["options"].index(doc["answer"])
pred = np.argmax(results)
return {
"acc": pred == gold
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"acc": mean
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"acc": True
}
class CBTCN(CBTBase):
DATASET_NAME = "CN"
class CBTNE(CBTBase):
DATASET_NAME = "NE"
import datasets import datasets
import lm_eval.metrics
from ..base import Task from ..base import Task
...@@ -26,30 +25,24 @@ class HFTask(Task): ...@@ -26,30 +25,24 @@ class HFTask(Task):
"""Whether the task has a test set""" """Whether the task has a test set"""
return True if "test" in self.data.keys() else False return True if "test" in self.data.keys() else False
def _convert_standard(self, doc):
return doc
def training_docs(self): def training_docs(self):
# Cache training for faster few-shot. # Cache training for faster few-shot.
# If data is too large to fit in memory, override this method. # If data is too large to fit in memory, override this method.
if self.has_training_docs(): if self.has_training_docs():
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.data["train"]) self._training_docs = list(map(self._convert_standard, self.data["train"]))
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
if self.has_validation_docs(): if self.has_validation_docs():
return self.data["validation"] return map(self._convert_standard, self.data["validation"])
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self.data["test"] return map(self._convert_standard, self.data["test"])
def simple_accuracy_metric(preds, golds):
acc = float(lm_eval.metrics.mean())
return {
"major": acc,
"minor": {"acc": acc},
"higher_is_better": True,
}
def yesno(x): def yesno(x):
......
...@@ -4,19 +4,20 @@ import transformers.data.metrics.squad_metrics as squad_metrics ...@@ -4,19 +4,20 @@ import transformers.data.metrics.squad_metrics as squad_metrics
from lm_eval.base import Task, rf, mean from lm_eval.base import Task, rf, mean
from ..utils import sh from ..utils import sh
from itertools import zip_longest from itertools import zip_longest
from best_download import download_file
class CoQA(Task): class CoQA(Task):
VERSION = 1
def download(self): def download(self):
coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json' coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json' coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
sh ("""mkdir -p data/coqa""") sh ("""mkdir -p data/coqa""")
if not os.path.exists(coqa_train_filepath):
sh ("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O """ + coqa_train_filepath) download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", coqa_train_filepath, "b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
if not os.path.exists(coqa_dev_filepath): download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", coqa_dev_filepath, "dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
sh ("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O """ + coqa_dev_filepath)
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -114,7 +115,7 @@ class CoQA(Task): ...@@ -114,7 +115,7 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
cont_request = rf.greedy_until(ctx, ['\n']) cont_request = rf.greedy_until(ctx, ['\nQ:'])
return cont_request return cont_request
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -129,7 +130,7 @@ class CoQA(Task): ...@@ -129,7 +130,7 @@ class CoQA(Task):
""" """
turn_id = len(doc["questions"]) turn_id = len(doc["questions"])
gold_list = self.get_answers(doc, turn_id) gold_list = self.get_answers(doc, turn_id)
pred = results[0] pred = results[0].strip().split('\n')[0]
scores = self.compute_scores(gold_list, pred) scores = self.compute_scores(gold_list, pred)
......
...@@ -14,8 +14,10 @@ Acknowledgement: This implementation is based on the official evaluation for `DR ...@@ -14,8 +14,10 @@ Acknowledgement: This implementation is based on the official evaluation for `DR
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
""" """
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task): class DROP(Task):
VERSION = 1
DATASET_PATH = Path("data/drop") DATASET_PATH = Path("data/drop")
def download(self): def download(self):
...@@ -49,19 +51,34 @@ class DROP(Task): ...@@ -49,19 +51,34 @@ class DROP(Task):
"id": qa["query_id"], "id": qa["query_id"],
"passage": doc["passage"], "passage": doc["passage"],
"question": qa["question"], "question": qa["question"],
"answers": self.get_answers(qa["answer"]), "answers": self.get_answers(qa),
} }
@classmethod @classmethod
def get_answers(cls, answers): def get_answers(cls, qa):
# NOTE: We wrap every non-`list` answer into a list for uniformity. answers = []
if answers["number"] != "": answers_set = set()
return [str(answers["number"])]
if answers["spans"] != []: candidates = [qa["answer"]] + qa.get("validated_answers", [])
return answers["spans"] for candidate in candidates:
return [" ".join([answers["date"]["day"], answer = cls.parse_answer(candidate)
answers["date"]["month"], if answer in answers_set:
answers["date"]["year"]]).strip()] continue
answers_set.add(answer)
answers.append(answer)
return answers
@classmethod
def parse_answer(cls, answer):
# NOTE: Everything is returned as a tuple for uniformity and hashability.
if answer["number"] != "":
return (str(answer["number"]),)
if answer["spans"] != []:
return tuple(answer["spans"])
return (" ".join([answer["date"]["day"],
answer["date"]["month"],
answer["date"]["year"]]).strip(),)
def training_docs(self): def training_docs(self):
docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json")) docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
...@@ -75,7 +92,7 @@ class DROP(Task): ...@@ -75,7 +92,7 @@ class DROP(Task):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"]) return " " + ", ".join(doc["answers"][0])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
...@@ -88,9 +105,7 @@ class DROP(Task): ...@@ -88,9 +105,7 @@ class DROP(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
conts = [] conts = [rf.greedy_until(ctx, ["."])]
for _ in doc["answers"]:
conts.append(rf.greedy_until(ctx, ["."]))
return conts return conts
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -104,66 +119,96 @@ class DROP(Task): ...@@ -104,66 +119,96 @@ class DROP(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
preds, golds = results, doc["answers"] preds, golds = results, doc["answers"]
exact_match, f1_score = self.get_metrics(preds, golds) max_em = 0
max_f1 = 0
for gold_answer in golds:
exact_match, f1_score = self.get_metrics(preds, gold_answer)
if gold_answer[0].strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
return { return {
"em": exact_match, "em": max_em,
"f1": f1_score "f1": max_f1
} }
def get_metrics(self, preds, golds): def get_metrics(self, predicted, gold):
exact_match = self._exact_match(preds, golds) """
f1_score = self._f1_score(preds, golds) Takes a predicted answer and a gold answer (that are both either a string or a list of
return exact_match, f1_score strings), and returns exact match and the DROP F1 metric for the prediction. If you are
writing a script for evaluating objects in memory (say, the output of predictions during
def _exact_match(self, preds, golds): validation, or while training), this is the function you want to call, after using
""" Returns the exact match of normalized gold answers and predictions. """ :func:`answer_json_to_strings` when reading the gold answer from the released data file.
normalized_preds = [self._normalize(pred) for pred in preds]
normalized_golds = [self._normalize(gold) for gold in golds]
is_equal_sets = set(normalized_preds) == set(normalized_golds)
is_equal_length = len(normalized_preds) == len(normalized_golds)
return int(is_equal_sets and is_equal_length)
def _f1_score(self, preds, golds):
"""Returns the average F1-score over normalized gold answers and predictions.
From Section 5 of Dua et al. "DROP:...":
"When an answer has multiple spans, we first perform a one-to-one
alignment greedily based on bag-of-word overlap on the set of spans
and then compute average F1 over each span."
""" """
pred_bags = self._answer_to_bags(preds) predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(golds) gold_bags = self._answer_to_bags(gold)
f1_per_bag = self._align_bags(pred_bags, gold_bags)
return np.mean(f1_per_bag) if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
exact_match = 1.0
def _answer_to_bags(self, answers): else:
return [set(self._normalize(answer).split()) for answer in answers] exact_match = 0.0
def _align_bags(self, pred_bags, gold_bags): f1_per_bag = self._align_bags(predicted_bags[1], gold_bags[1])
""" Returns the max metric value over all the answers. """ f1 = np.mean(f1_per_bag)
scores = np.zeros([len(gold_bags), len(pred_bags)]) f1 = round(f1, 2)
for gold_index, gold_bag in enumerate(gold_bags): return exact_match, f1
for pred_index, pred_bag in enumerate(pred_bags):
if self._is_number_match(pred_bag, gold_bag): def _answer_to_bags(self, answer):
scores[gold_index, pred_index] = self._bag_f1(pred_bag, gold_bag) if isinstance(answer, (list, tuple)):
raw_spans = answer
else:
raw_spans = [answer]
normalized_spans = []
token_bags = []
for raw_span in raw_spans:
normalized_span = self._normalize(raw_span)
normalized_spans.append(normalized_span)
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags
def _align_bags(self, predicted, gold):
"""
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
between them and gets maximum metric values over all the answers.
"""
scores = np.zeros([len(gold), len(predicted)])
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
row_ind, col_ind = linear_sum_assignment(-scores) row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold_bags), len(pred_bags))])
max_scores = np.zeros([max(len(gold), len(predicted))])
for row, column in zip(row_ind, col_ind): for row, column in zip(row_ind, col_ind):
max_scores[row] = max(max_scores[row], scores[row, column]) max_scores[row] = max(max_scores[row], scores[row, column])
return max_scores return max_scores
def _bag_f1(self, pred_bag, gold_bag): def _compute_f1(self, predicted_bag, gold_bag):
intersection = len(gold_bag.intersection(pred_bag)) intersection = len(gold_bag.intersection(predicted_bag))
if intersection == 0: if not predicted_bag:
return 0.0 precision = 1.0
precision = intersection / float(len(pred_bag)) if pred_bag else 1.0 else:
recall = intersection / float(len(gold_bag)) if gold_bag else 1.0 precision = intersection / float(len(predicted_bag))
f1 = (2 * precision * recall) / (precision + recall) if not gold_bag:
recall = 1.0
else:
recall = intersection / float(len(gold_bag))
f1 = (
(2 * precision * recall) / (precision + recall)
if not (precision == 0.0 and recall == 0.0)
else 0.0
)
return f1 return f1
def _is_number_match(self, pred_bag, gold_bag): def _match_numbers_if_present(self, gold_bag, predicted_bag):
pred_numbers = set([word for word in pred_bag if self._is_number(word)]) gold_numbers = set()
gold_numbers = set([word for word in gold_bag if self._is_number(word)]) predicted_numbers = set()
if (not gold_numbers) or gold_numbers.intersection(pred_numbers): for word in gold_bag:
if self._is_number(word):
gold_numbers.add(word)
for word in predicted_bag:
if self._is_number(word):
predicted_numbers.add(word)
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
return True return True
return False return False
...@@ -174,30 +219,29 @@ class DROP(Task): ...@@ -174,30 +219,29 @@ class DROP(Task):
except ValueError: except ValueError:
return False return False
def _normalize(self, answer): def _remove_articles(self, text):
def remove_articles(text): return _ARTICLES.sub(" ", text)
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def white_space_fix(text): def _white_space_fix(self, text):
return " ".join(text.split()) return " ".join(text.split())
def remove_punc(text): def _remove_punc(self, text):
exclude = set(string.punctuation) exclude = set(string.punctuation)
if not self._is_number(text): if not self._is_number(text):
return "".join(ch for ch in text if ch not in exclude) return "".join(ch for ch in text if ch not in exclude)
else: else:
return text return text
def fix_number(text): def _fix_number(self, text):
return str(float(text)) if self._is_number(text) else text return str(float(text)) if self._is_number(text) else text
def tokenize(text): def _tokenize(self, text):
return re.split(" |-", text) return re.split(" |-", text)
def _normalize(self, answer):
tokens = [ tokens = [
white_space_fix(remove_articles(fix_number(remove_punc(token.lower())))) self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
for token in tokenize(answer) for token in self._tokenize(answer)
] ]
tokens = [token for token in tokens if token.strip()] tokens = [token for token in tokens if token.strip()]
normalized = " ".join(tokens).strip() normalized = " ".join(tokens).strip()
......
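As a hedged aside on the alignment step above: calling linear_sum_assignment on the negated score matrix selects the one-to-one pairing of gold and predicted bags that maximizes total F1, for example:
import numpy as np
from scipy.optimize import linear_sum_assignment
scores = np.array([[0.8, 0.1],   # F1 of (gold span i, predicted span j)
                   [0.2, 0.9]])
row_ind, col_ind = linear_sum_assignment(-scores)  # negate to maximize
print(scores[row_ind, col_ind])                    # [0.8 0.9] -> optimal 1-1 alignment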