Commit 55e62507 authored by researcher2

Merge branch 'master' into researcher2

parents bb0eafbb 26f0233f
# Description Guide
![fewshot-example](./img/fewshot_example_gpt3.png)
(Figure from [Brown et al., 2020](https://arxiv.org/pdf/2005.14165.pdf))
A task description provides in-context instructions for your language model. If you'd like to prepend a natural language description to your few-shot examples and prompt, you can do so on a per-task basis via the `description_dict` arg of [`evaluator.evaluate`](../lm_eval/evaluator.py). This `description_dict` must adhere to the following key-value structure:
- **key**: the task name (`str`) as specified in the lm-eval-harness [task registry](../lm_eval/tasks/__init__.py).
- **value**: the corresponding (`str`) description/prompt for the task identified by **key**.
```python
description_dict = {
"task_name_1": "description",
"task_name_2": "description",
...
}
```
Note that a task's description will be separated from the few-shot examples and prompt that follow it by a newline, like so:
```python
"""
<description>
<examples>
<prompt>
"""
```
## Descriptions in File
You can also work at a higher level through the command-line interface (CLI) program, `main.py`, by passing a JSON file path to its `description_dict_path` arg; the file is loaded and forwarded to [`evaluator.evaluate`](../lm_eval/evaluator.py) (or `evaluator.simple_evaluate`). The JSON file should be structured the same as the `description_dict`. E.g., for a file at `/your/path/descriptions.json` you might have:
```json
{
"cycle_letters": "Please unscramble the letters into a word, and write that word:",
"copa": "Given a premise and one alternative with a causal relation to the premise and another without, choose the more plausible alternative"
}
```
which can then be supplied to the CLI as:
```bash
python main.py \
--tasks cycle_letters,copa \
--description_dict_path /your/path/descriptions.json \
...
```
@@ -87,8 +87,7 @@ There are 2 standard approaches we follow for downloading data:
```
These methods return `True`/`False` whether or not your task dataset provides documents for each split type. __Note__: if the test set doesn't have publicly available labels, please do not put it down as having a test set.
Lastly, we need to load the documents. In our terminology, a document (`doc`) is a single natural language data example stored in a Python `dict`. E.g.: `{"question": "What is the capital of France?", "answer": "Paris"}`. Override the following methods to load your data splits from their storage location in `DATASET_PATH`:
```python
def training_docs(self):
    return #...
@@ -125,17 +124,9 @@ You can now skip ahead to <a href="#Registering-Your-Task">registering your task
<br>
In the case your task is _not_ multiple-choice, override the following methods for your task class:
Format your document into a single query prompt __without the answer__ here. This method takes a single `doc` example of type `dict` with `str` key-value members. You should concatenate these `doc` item values together into a neatly formatted prompt.
```python
def doc_to_text(self, doc):
@@ -161,11 +152,12 @@ After registering your task, you can now check on your data downloading and verification
```bash
python -m scripts.write_out \
    --output_base_path <path> \
    --tasks <your-task> \
    --sets <train | val | test> \
    --num_fewshot K \
    --num_examples N \
    --description_dict_path <path>
```
Open the file specified at the `--output_base_path <path>` and ensure it passes
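To make the task-guide excerpt above concrete, here is a rough sketch of the overrides it describes for a hypothetical free-form QA task (the class name, file paths, and field names are illustrative only, not part of this commit):
```python
import json
from lm_eval.base import Task


class MyQATask(Task):  # hypothetical task, for illustration only
    VERSION = 0
    DATASET_PATH = "data/my_qa"  # assumed storage location

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        # the test split has no public labels, so we report it as unavailable
        return False

    def training_docs(self):
        # each doc is a plain dict, e.g. {"question": "...", "answer": "..."}
        with open(f"{self.DATASET_PATH}/train.jsonl") as f:
            return [json.loads(line) for line in f]

    def validation_docs(self):
        with open(f"{self.DATASET_PATH}/valid.jsonl") as f:
            return [json.loads(line) for line in f]

    def doc_to_text(self, doc):
        # the query prompt, without the answer
        return "Question: " + doc["question"] + "\nAnswer:"

    def doc_to_target(self, doc):
        return " " + doc["answer"]
```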
@@ -7,23 +7,71 @@ import lm_eval.tasks
import lm_eval.base
from scripts.clean_training_data.contamination import get_train_overlap
import numpy as np
from lm_eval.utils import positional_deprecated


@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
                    num_fewshot=0, batch_size=None, device=None,
                    no_cache=False, limit=None, bootstrap_iters=100000,
                    description_dict=None, decontaminate=False,
                    ngrams_path=None, ngrams_n_size=None):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to cache
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None: model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
            'batch_size': batch_size, 'device': device
        })
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        description_dict=description_dict,
        decontaminate=decontaminate,
        ngrams_path=ngrams_path,
        ngrams_n_size=ngrams_n_size
    )

    # add info about the model and few shot config
    results["config"] = {
@@ -34,23 +82,52 @@ def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None, device=None,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict
    }

    return results


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None,
             decontaminate=False, ngrams_path=None, ngrams_n_size=None):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
    # TODO: todo: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

    if decontaminate:
        assert ngrams_path and ngrams_n_size

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if(task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)
@@ -60,20 +137,20 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
    # - probably using an sqlite db because of all the moving parts we have

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test doc, fall back to val doc if validation unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
@@ -81,6 +158,8 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
@@ -88,27 +167,28 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = description_dict[task_name] if description_dict and task_name in description_dict else ""

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc,
                num_fewshot=num_fewshot,
                rnd=rnd,
                description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
@@ -120,13 +200,13 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #  only in index. We could implement some kind of caching, but that would be more of a band-aid
        #  solution. we could also implement some kind of auto-grouping here;
        #  they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
@@ -161,7 +241,12 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them less iterations. still looking for a cleaner way to do this
            stderr = lm_eval.metrics.stderr_for_metric(
                metric=task.aggregation()[real_metric],
                bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
            )

            if stderr is not None:
                results[task_name][metric + "_stderr"] = stderr(items)
@@ -172,6 +257,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
@@ -184,11 +270,11 @@ def make_table(result_dict):
    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
@@ -200,4 +286,4 @@ def make_table(result_dict):
    # todo: make latex table look good
    # print(latex_writer.dumps())
    return md_writer.dumps()
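As a usage note on the refactored `simple_evaluate` signature above: `model` may now be either a registry name string or an already-constructed `LM` object. A minimal sketch (the argument string and task choice are illustrative; `no_cache=True` is used because the cache path shown above is built from the model name string):
```python
import lm_eval.models
from lm_eval.evaluator import simple_evaluate

# build the LM object ourselves instead of passing a name string
lm = lm_eval.models.get_model("gpt2").create_from_arg_string(
    "pretrained=gpt2", {"batch_size": 1, "device": "cpu"}
)

results = simple_evaluate(
    model=lm,            # an lm_eval.base.LM instance
    tasks=["lambada"],   # task names and/or Task objects
    num_fewshot=0,
    no_cache=True,       # CachingLM above assumes `model` is a string
)
```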
import math
from collections.abc import Iterable

import numpy as np
import sacrebleu
@@ -53,16 +52,18 @@ def acc_all(items):
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
@@ -98,9 +99,13 @@ def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)


def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
@@ -179,12 +184,13 @@ def _sacreformat(refs, preds):
    return refs, preds


# stderr stuff

class _bootstrap_internal:
    def __init__(self, f, n):
        self.f = f
        self.n = n

    def __call__(self, v):
        i, xs = v
        rnd = random.Random()
@@ -208,7 +214,9 @@ def bootstrap_stderr(f, xs, iters):
    chunk_size = min(1000, iters)
    from tqdm import tqdm
    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
        # sample w replacement
        res.extend(bootstrap)
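As a quick sanity check on the `bits_per_byte` helper added above: assuming each item is a `(log_likelihood_in_nats, byte_count)` pair, as with `weighted_mean`, the helpers are restated here so the snippet runs standalone (the numbers are made up for illustration):
```python
import math

def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)

def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)

# one document: total log-likelihood of -693.1 nats over 1000 bytes
items = [(-693.1, 1000)]
print(bits_per_byte(items))  # ~1.0, i.e. roughly one bit per byte
```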
@@ -3,6 +3,7 @@ from . import gpt3
from . import dummy

MODEL_REGISTRY = {
    "hf": gpt2.HFLM,
    "gpt2": gpt2.GPT2LM,
    "gpt3": gpt3.GPT3LM,
    "dummy": dummy.DummyLM,
import transformers
import torch
from lm_eval.base import BaseLM


class HFLM(BaseLM):

    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
        super().__init__()

@@ -29,222 +23,86 @@ class GPT2LM(LM):
        self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
        ).to(self.device)
        self.gpt2.eval()

        # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
        assert isinstance(self.tokenizer, (
            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
            transformers.T5Tokenizer, transformers.T5TokenizerFast,
        )), "this tokenizer has not been checked for compatibility yet!"

        self.vocab_size = self.tokenizer.vocab_size

        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
                self.tokenizer.encode('hello\n\nhello')

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # TODO: fix multi-gpu
        # gpus = torch.cuda.device_count()
        # if gpus > 1:
        #     self.gpt2 = nn.DataParallel(self.gpt2)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.gpt2.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.gpt2(inps)[0][:, :, :50257]

    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
            context,
            max_length=max_length,
            eos_token_id=eos_token_id,
            do_sample=False
        )


# for backwards compatibility
GPT2LM = HFLM
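For reference, the refactored class above is what both the `gpt2` and the new `hf` keys in `MODEL_REGISTRY` resolve to, so it can also be constructed directly. A small sketch (the checkpoint name and batch size are arbitrary examples):
```python
from lm_eval.models.gpt2 import HFLM

# HFLM wraps an AutoModelForCausalLM checkpoint with a GPT-2/T5-style tokenizer
lm = HFLM(pretrained="gpt2", batch_size=1)

print(lm.eot_token_id)          # tokenizer's end-of-text id
print(lm.max_length)            # context length (n_ctx or max_position_embeddings)
print(lm.tok_encode("hello"))   # encoding without special tokens
```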
import os
import numpy as np
import transformers
from lm_eval.base import BaseLM
from lm_eval import utils
from tqdm import tqdm
import time


def get_result(response, ctxlen):
    """Process results from OpenAI API response.

    :param response: dict
        OpenAI API Response
    :param ctxlen: int
        Length of context (so we can slice them away and only keep the predictions)
    :return:
        continuation_logprobs: np.array
            Log probabilities of continuation tokens
        is_greedy: bool
            whether argmax matches given continuation exactly
    """
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])
@@ -24,8 +36,11 @@ def get_result(response, ctxlen):


def oa_completion(**kwargs):
    """ Query OpenAI API for completion.

    Retry with back-off until they respond
    """
    import openai

    backoff_time = 3
    while True:
        try:
@@ -35,11 +50,8 @@ def oa_completion(**kwargs):
            backoff_time *= 1.5


class GPT3LM(BaseLM):
    REQ_CHUNK_SIZE = 20

    def __init__(self, engine, truncate=False):
        """
@@ -50,10 +62,12 @@ class GPT3LM(LM):
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

        self.vocab_size = self.tokenizer.vocab_size

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
@@ -64,53 +78,36 @@ class GPT3LM(LM):
        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
@@ -118,16 +115,18 @@ class GPT3LM(LM):
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        reord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
                # max_length+1 because the API takes up to 2049 tokens, including the first context token
                inp = (context_enc + continuation_enc)[-(self.max_length+1):]
                # TODO: the logic is much simpler if we just look at the length of continuation tokens
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))

                inps.append(inp)
                ctxlens.append(ctxlen)
@@ -151,35 +150,14 @@ class GPT3LM(LM):

        return reord.get_original(res)

    def greedy_until(self, requests):
        if not requests:
            return []
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)
@@ -193,34 +171,43 @@ class GPT3LM(LM):
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        # todo: more intelligent batching for heterogeneous `until`
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks):]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
                max_tokens=self.max_gen_toks,
                temperature=0.,
                logprobs=10,
                stop=until,
            )

            for resp, (context, until_) in zip(response.choices, chunk):
                s = resp['text']

                for term in until_:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until_), s)

                res.append(s)

        return reord.get_original(res)

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
from pprint import pprint
from typing import List, Union

import sacrebleu
import lm_eval.base

from . import superglue
from . import glue
@@ -44,6 +46,8 @@ from . import wikitext
from . import lambada_multilingual
from . import mutual
from . import truthfulqa
from . import blimp
from . import asdiv

########################################
# Translation tasks
@@ -132,7 +136,9 @@ TASK_REGISTRY = {
"squad2": squad.SQuAD2,
"race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQAEsDeprecated,  # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn,
"mathqa": mathqa.MathQA,
"webqs": webqs.WebQs,
"wsc273": wsc273.WinogradSchemaChallenge273,
@@ -163,6 +169,7 @@ TASK_REGISTRY = {
"math_num_theory": hendrycks_math.MathNumberTheory,
"math_prealgebra": hendrycks_math.MathPrealgebra,
"math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv,

# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
@@ -217,6 +224,74 @@ TASK_REGISTRY = {
"pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
"blimp_anaphor_number_agreement": blimp.BlimpAnaphorNumberAgreement,
"blimp_animate_subject_passive": blimp.BlimpAnimateSubjectPassive,
"blimp_animate_subject_trans": blimp.BlimpAnimateSubjectTrans,
"blimp_causative": blimp.BlimpCausative,
"blimp_complex_NP_island": blimp.BlimpComplex_NPIsland,
"blimp_coordinate_structure_constraint_complex_left_branch": blimp.BlimpCoordinateStructureConstraintComplexLeftBranch,
"blimp_coordinate_structure_constraint_object_extraction": blimp.BlimpCoordinateStructureConstraintObjectExtraction,
"blimp_determiner_noun_agreement_1": blimp.BlimpDeterminerNounAgreement_1,
"blimp_determiner_noun_agreement_2": blimp.BlimpDeterminerNounAgreement_2,
"blimp_determiner_noun_agreement_irregular_1": blimp.BlimpDeterminerNounAgreementIrregular_1,
"blimp_determiner_noun_agreement_irregular_2": blimp.BlimpDeterminerNounAgreementIrregular_2,
"blimp_determiner_noun_agreement_with_adj_2": blimp.BlimpDeterminerNounAgreementWithAdj_2,
"blimp_determiner_noun_agreement_with_adj_irregular_1": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_1,
"blimp_determiner_noun_agreement_with_adj_irregular_2": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_2,
"blimp_determiner_noun_agreement_with_adjective_1": blimp.BlimpDeterminerNounAgreementWithAdjective_1,
"blimp_distractor_agreement_relational_noun": blimp.BlimpDistractorAgreementRelationalNoun,
"blimp_distractor_agreement_relative_clause": blimp.BlimpDistractorAgreementRelativeClause,
"blimp_drop_argument": blimp.BlimpDropArgument,
"blimp_ellipsis_n_bar_1": blimp.BlimpEllipsisNBar_1,
"blimp_ellipsis_n_bar_2": blimp.BlimpEllipsisNBar_2,
"blimp_existential_there_object_raising": blimp.BlimpExistentialThereObjectRaising,
"blimp_existential_there_quantifiers_1": blimp.BlimpExistentialThereQuantifiers_1,
"blimp_existential_there_quantifiers_2": blimp.BlimpExistentialThereQuantifiers_2,
"blimp_existential_there_subject_raising": blimp.BlimpExistentialThereSubjectRaising,
"blimp_expletive_it_object_raising": blimp.BlimpExpletiveItObjectRaising,
"blimp_inchoative": blimp.BlimpInchoative,
"blimp_intransitive": blimp.BlimpIntransitive,
"blimp_irregular_past_participle_adjectives": blimp.BlimpIrregularPastParticipleAdjectives,
"blimp_irregular_past_participle_verbs": blimp.BlimpIrregularPastParticipleVerbs,
"blimp_irregular_plural_subject_verb_agreement_1": blimp.BlimpIrregularPluralSubjectVerbAgreement_1,
"blimp_irregular_plural_subject_verb_agreement_2": blimp.BlimpIrregularPluralSubjectVerbAgreement_2,
"blimp_left_branch_island_echo_question": blimp.BlimpLeftBranchIslandEchoQuestion,
"blimp_left_branch_island_simple_question": blimp.BlimpLeftBranchIslandSimpleQuestion,
"blimp_matrix_question_npi_licensor_present": blimp.BlimpMatrixQuestionNpiLicensorPresent,
"blimp_npi_present_1": blimp.BlimpNpiPresent_1,
"blimp_npi_present_2": blimp.BlimpNpiPresent_2,
"blimp_only_npi_licensor_present": blimp.BlimpOnlyNpiLicensorPresent,
"blimp_only_npi_scope": blimp.BlimpOnlyNpiScope,
"blimp_passive_1": blimp.BlimpPassive_1,
"blimp_passive_2": blimp.BlimpPassive_2,
"blimp_principle_A_c_command": blimp.BlimpPrinciple_ACCommand,
"blimp_principle_A_case_1": blimp.BlimpPrinciple_ACase_1,
"blimp_principle_A_case_2": blimp.BlimpPrinciple_ACase_2,
"blimp_principle_A_domain_1": blimp.BlimpPrinciple_ADomain_1,
"blimp_principle_A_domain_2": blimp.BlimpPrinciple_ADomain_2,
"blimp_principle_A_domain_3": blimp.BlimpPrinciple_ADomain_3,
"blimp_principle_A_reconstruction": blimp.BlimpPrinciple_AReconstruction,
"blimp_regular_plural_subject_verb_agreement_1": blimp.BlimpRegularPluralSubjectVerbAgreement_1,
"blimp_regular_plural_subject_verb_agreement_2": blimp.BlimpRegularPluralSubjectVerbAgreement_2,
"blimp_sentential_negation_npi_licensor_present": blimp.BlimpSententialNegationNpiLicensorPresent,
"blimp_sentential_negation_npi_scope": blimp.BlimpSententialNegationNpiScope,
"blimp_sentential_subject_island": blimp.BlimpSententialSubjectIsland,
"blimp_superlative_quantifiers_1": blimp.BlimpSuperlativeQuantifiers_1,
"blimp_superlative_quantifiers_2": blimp.BlimpSuperlativeQuantifiers_2,
"blimp_tough_vs_raising_1": blimp.BlimpToughVsRaising_1,
"blimp_tough_vs_raising_2": blimp.BlimpToughVsRaising_2,
"blimp_transitive": blimp.BlimpTransitive,
"blimp_wh_island": blimp.BlimpWhIsland,
"blimp_wh_questions_object_gap": blimp.BlimpWhQuestionsObjectGap,
"blimp_wh_questions_subject_gap": blimp.BlimpWhQuestionsSubjectGap,
"blimp_wh_questions_subject_gap_long_distance": blimp.BlimpWhQuestionsSubjectGapLongDistance,
"blimp_wh_vs_that_no_gap": blimp.BlimpWhVsThatNoGap,
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
}
@@ -232,8 +307,23 @@ def get_task(task_name):
    raise KeyError(f"Missing task {task_name}")


def get_task_name_from_object(task_object):
    for name, class_ in TASK_REGISTRY.items():
        if class_ is task_object:
            return name

    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__


def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
    task_name_dict = {
        task_name: get_task(task_name)()
        for task_name in task_name_list if isinstance(task_name, str)
    }
    task_name_from_object_dict = {
        get_task_name_from_object(task_object): task_object
        for task_object in task_name_list if not isinstance(task_object, str)
    }

    assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
    return {**task_name_dict, **task_name_from_object_dict}
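A small illustration of the new `get_task_dict` behaviour above: registered task names and already-constructed `Task` objects can now be mixed in a single list (the particular tasks chosen here are arbitrary):
```python
from lm_eval.tasks import get_task_dict
from lm_eval.tasks import asdiv

# a registered name and a pre-constructed Task object in one call
task_dict = get_task_dict([
    "copa",           # string: looked up in TASK_REGISTRY and instantiated
    asdiv.Asdiv(),    # Task object: passed through, name resolved for reporting
])
print(sorted(task_dict.keys()))
```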
@@ -33,10 +33,6 @@ class ANLIBase(HFTask):
        if self.has_test_docs():
            return self.data["test_r" + str(self.SPLIT)]

    def doc_to_text(self, doc):
        # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
        # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
@@ -29,10 +29,6 @@ class ARCEasy(HFTask, MultipleChoiceTask):
        }
        return out_doc

    def doc_to_text(self, doc):
        return doc["query"]
@@ -21,7 +21,7 @@ class Arithmetic(Task):
        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
        if not os.path.exists(self.directory):
            os.makedirs(self.directory)
        download_file(url, local_file=self.directory+file_name, expected_checksum=checksum)
        self.set_docs()

    @abc.abstractmethod
"""
ASDiv: A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers
https://arxiv.org/abs/2106.15772
@misc{miao2021diverse,
title={A Diverse Corpus for Evaluating and Developing English Math Word Problem Solvers},
author={Shen-Yun Miao and Chao-Chun Liang and Keh-Yih Su},
year={2021},
eprint={2106.15772},
archivePrefix={arXiv},
primaryClass={cs.AI}
}
"""
from lm_eval.base import Task
from pathlib import Path
from best_download import download_file
import xml.etree.ElementTree as ET
from lm_eval.base import rf
from lm_eval.metrics import mean,perplexity
import numpy as np
from zipfile import ZipFile
import os
#currently ignoring formula for answer generation
# given a subset, splits return the docs
class Asdiv(Task):
VERSION = 0
DATASET_PATH = Path("data/asdiv")
def download(self):
if self.DATASET_PATH.exists():
return
Path.mkdir(self.DATASET_PATH, parents=True)
url = "https://github.com/chaochun/nlu-asdiv-dataset/archive/55790e5270bb91ccfa5053194b25732534696b50.zip"
checksum = "8f1fe4f6d5f170ec1e24ab78c244153c14c568b1bb2b1dad0324e71f37939a2d"
zip_path = self.DATASET_PATH / "55790e5270bb91ccfa5053194b25732534696b50.zip"
download_file(url, local_file=str(zip_path), expected_checksum=checksum)
with ZipFile(zip_path, "r") as zip:
zip.extractall(self.DATASET_PATH)
os.remove(zip_path)
def _convert_standard(self, problem):
#TODO: include solution-type and formula
out_doc = {
"question" : problem.find('Question').text,
"body" : problem.find('Body').text,
"answer": problem.find('Answer').text
}
return out_doc
def load_docs(self, textfilename, tfds=False):
tree = ET.parse(textfilename)
root = tree.getroot()
for pid, problem in enumerate(root.iter('Problem')):
out_doc = self._convert_standard(problem)
yield out_doc
def has_training_docs(self):
return False
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
raise NotImplementedError("This dataset has no training docs")
def test_docs(self):
raise NotImplementedError("This dataset has no test docs")
def validation_docs(self):
data_xml_path = self.DATASET_PATH / "nlu-asdiv-dataset-55790e5270bb91ccfa5053194b25732534696b50/dataset/ASDiv.xml"
return self.load_docs(data_xml_path)
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "ASDiv is intended only for the zero-shot setting."
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
)
def fewshot_description(self):
# TODO: add solution-type and formula
desc = "information containing the context of the question\nQuestion: Text of a question.\nAnswer: Answer to the question, based on the passage.\n"
return desc
def doc_to_text(self, doc):
# TODO: add solution-type
return doc['body'] + '\n' + 'Question:' + doc['question'] + '\n' + 'Answer:'
def doc_to_target(self, doc):
# TODO: add formula
answer = doc['answer'].split(' (')[0]
return " " + answer
def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy
def process_results(self, doc, results):
ll, is_greedy = results
return {
'acc': int(is_greedy)
}
def aggregation(self):
return {
'acc': mean
}
def higher_is_better(self):
return {
'acc': True
}
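To illustrate the prompt format produced by `doc_to_text` and `doc_to_target` above (the problem text is a made-up stand-in for a real ASDiv item):
```python
doc = {
    "body": "Seven red apples and two green apples are in the basket.",
    "question": "How many apples are in the basket?",
    "answer": "9 (apples)",
}

# doc_to_text(doc) ->
# Seven red apples and two green apples are in the basket.
# Question:How many apples are in the basket?
# Answer:

# doc_to_target(doc) -> " 9"   (the " (apples)" unit annotation is stripped)
```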
"""
BLiMP: A Benchmark of Linguistic Minimal Pairs for English
https://arxiv.org/abs/1912.00582
@article{warstadt2019blimp,
title={BLiMP: A Benchmark of Linguistic Minimal Pairs for English},
author={Warstadt, Alex and Parrish, Alicia and Liu, Haokun and Mohananey, Anhad and Peng, Wei and Wang, Sheng-Fu and Bowman, Samuel R},
journal={arXiv preprint arXiv:1912.00582},
year={2019}
}
"""
from lm_eval.base import rf
from lm_eval.metrics import mean
from .common import HFTask
class BlimpTask(HFTask):
VERSION = 0
DATASET_PATH = "blimp"
def download(self):
super().download()
# The HF dataset only contains a "train" dataset, but the harness expects a "validation"
# dataset. Let's use the training dataset, on the assumption that the model wasn't actually
# trained on this data.
self.data["validation"] = self.data["train"]
del self.data["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
return ""
def doc_to_text(self, doc):
# this method is invoked by tests only
return ""
def doc_to_target(self, doc):
# this method is invoked by tests only
return ""
def construct_requests(self, doc, ctx):
assert not ctx
# Calculate the loglikelihood for the good and the bad sentence.
# Note that loglikelihood translates the "" prefix to the "<|endoftext|>" token
return [
rf.loglikelihood("", doc["sentence_good"]),
rf.loglikelihood("", doc["sentence_bad"]),
]
def process_results(self, doc, results):
likelihood1, likelihood2 = results
# the model got this case right iff the good sentence scored higher than the bad sentence
acc = 1.0 if likelihood1 > likelihood2 else 0.0
return {
"acc": acc,
}
def higher_is_better(self):
return {
"acc": True,
}
def aggregation(self):
return {
"acc": mean,
}
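
# One subclass per BLiMP paradigm; DATASET_NAME selects the corresponding
# configuration of the HF "blimp" dataset.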
class BlimpAdjunctIsland(BlimpTask):
DATASET_NAME = "adjunct_island"
class BlimpAnaphorGenderAgreement(BlimpTask):
DATASET_NAME = "anaphor_gender_agreement"
class BlimpAnaphorNumberAgreement(BlimpTask):
DATASET_NAME = "anaphor_number_agreement"
class BlimpAnimateSubjectPassive(BlimpTask):
DATASET_NAME = "animate_subject_passive"
class BlimpAnimateSubjectTrans(BlimpTask):
DATASET_NAME = "animate_subject_trans"
class BlimpCausative(BlimpTask):
DATASET_NAME = "causative"
class BlimpComplex_NPIsland(BlimpTask):
DATASET_NAME = "complex_NP_island"
class BlimpCoordinateStructureConstraintComplexLeftBranch(BlimpTask):
DATASET_NAME = "coordinate_structure_constraint_complex_left_branch"
class BlimpCoordinateStructureConstraintObjectExtraction(BlimpTask):
DATASET_NAME = "coordinate_structure_constraint_object_extraction"
class BlimpDeterminerNounAgreement_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_1"
class BlimpDeterminerNounAgreement_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_2"
class BlimpDeterminerNounAgreementIrregular_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_irregular_1"
class BlimpDeterminerNounAgreementIrregular_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_irregular_2"
class BlimpDeterminerNounAgreementWithAdj_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adj_2"
class BlimpDeterminerNounAgreementWithAdjIrregular_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adj_irregular_1"
class BlimpDeterminerNounAgreementWithAdjIrregular_2(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adj_irregular_2"
class BlimpDeterminerNounAgreementWithAdjective_1(BlimpTask):
DATASET_NAME = "determiner_noun_agreement_with_adjective_1"
class BlimpDistractorAgreementRelationalNoun(BlimpTask):
DATASET_NAME = "distractor_agreement_relational_noun"
class BlimpDistractorAgreementRelativeClause(BlimpTask):
DATASET_NAME = "distractor_agreement_relative_clause"
class BlimpDropArgument(BlimpTask):
DATASET_NAME = "drop_argument"
class BlimpEllipsisNBar_1(BlimpTask):
DATASET_NAME = "ellipsis_n_bar_1"
class BlimpEllipsisNBar_2(BlimpTask):
DATASET_NAME = "ellipsis_n_bar_2"
class BlimpExistentialThereObjectRaising(BlimpTask):
DATASET_NAME = "existential_there_object_raising"
class BlimpExistentialThereQuantifiers_1(BlimpTask):
DATASET_NAME = "existential_there_quantifiers_1"
class BlimpExistentialThereQuantifiers_2(BlimpTask):
DATASET_NAME = "existential_there_quantifiers_2"
class BlimpExistentialThereSubjectRaising(BlimpTask):
DATASET_NAME = "existential_there_subject_raising"
class BlimpExpletiveItObjectRaising(BlimpTask):
DATASET_NAME = "expletive_it_object_raising"
class BlimpInchoative(BlimpTask):
DATASET_NAME = "inchoative"
class BlimpIntransitive(BlimpTask):
DATASET_NAME = "intransitive"
class BlimpIrregularPastParticipleAdjectives(BlimpTask):
DATASET_NAME = "irregular_past_participle_adjectives"
class BlimpIrregularPastParticipleVerbs(BlimpTask):
DATASET_NAME = "irregular_past_participle_verbs"
class BlimpIrregularPluralSubjectVerbAgreement_1(BlimpTask):
DATASET_NAME = "irregular_plural_subject_verb_agreement_1"
class BlimpIrregularPluralSubjectVerbAgreement_2(BlimpTask):
DATASET_NAME = "irregular_plural_subject_verb_agreement_2"
class BlimpLeftBranchIslandEchoQuestion(BlimpTask):
DATASET_NAME = "left_branch_island_echo_question"
class BlimpLeftBranchIslandSimpleQuestion(BlimpTask):
DATASET_NAME = "left_branch_island_simple_question"
class BlimpMatrixQuestionNpiLicensorPresent(BlimpTask):
DATASET_NAME = "matrix_question_npi_licensor_present"
class BlimpNpiPresent_1(BlimpTask):
DATASET_NAME = "npi_present_1"
class BlimpNpiPresent_2(BlimpTask):
DATASET_NAME = "npi_present_2"
class BlimpOnlyNpiLicensorPresent(BlimpTask):
DATASET_NAME = "only_npi_licensor_present"
class BlimpOnlyNpiScope(BlimpTask):
DATASET_NAME = "only_npi_scope"
class BlimpPassive_1(BlimpTask):
DATASET_NAME = "passive_1"
class BlimpPassive_2(BlimpTask):
DATASET_NAME = "passive_2"
class BlimpPrinciple_ACCommand(BlimpTask):
DATASET_NAME = "principle_A_c_command"
class BlimpPrinciple_ACase_1(BlimpTask):
DATASET_NAME = "principle_A_case_1"
class BlimpPrinciple_ACase_2(BlimpTask):
DATASET_NAME = "principle_A_case_2"
class BlimpPrinciple_ADomain_1(BlimpTask):
DATASET_NAME = "principle_A_domain_1"
class BlimpPrinciple_ADomain_2(BlimpTask):
DATASET_NAME = "principle_A_domain_2"
class BlimpPrinciple_ADomain_3(BlimpTask):
DATASET_NAME = "principle_A_domain_3"
class BlimpPrinciple_AReconstruction(BlimpTask):
DATASET_NAME = "principle_A_reconstruction"
class BlimpRegularPluralSubjectVerbAgreement_1(BlimpTask):
DATASET_NAME = "regular_plural_subject_verb_agreement_1"
class BlimpRegularPluralSubjectVerbAgreement_2(BlimpTask):
DATASET_NAME = "regular_plural_subject_verb_agreement_2"
class BlimpSententialNegationNpiLicensorPresent(BlimpTask):
DATASET_NAME = "sentential_negation_npi_licensor_present"
class BlimpSententialNegationNpiScope(BlimpTask):
DATASET_NAME = "sentential_negation_npi_scope"
class BlimpSententialSubjectIsland(BlimpTask):
DATASET_NAME = "sentential_subject_island"
class BlimpSuperlativeQuantifiers_1(BlimpTask):
DATASET_NAME = "superlative_quantifiers_1"
class BlimpSuperlativeQuantifiers_2(BlimpTask):
DATASET_NAME = "superlative_quantifiers_2"
class BlimpToughVsRaising_1(BlimpTask):
DATASET_NAME = "tough_vs_raising_1"
class BlimpToughVsRaising_2(BlimpTask):
DATASET_NAME = "tough_vs_raising_2"
class BlimpTransitive(BlimpTask):
DATASET_NAME = "transitive"
class BlimpWhIsland(BlimpTask):
DATASET_NAME = "wh_island"
class BlimpWhQuestionsObjectGap(BlimpTask):
DATASET_NAME = "wh_questions_object_gap"
class BlimpWhQuestionsSubjectGap(BlimpTask):
DATASET_NAME = "wh_questions_subject_gap"
class BlimpWhQuestionsSubjectGapLongDistance(BlimpTask):
DATASET_NAME = "wh_questions_subject_gap_long_distance"
class BlimpWhVsThatNoGap(BlimpTask):
DATASET_NAME = "wh_vs_that_no_gap"
class BlimpWhVsThatNoGapLongDistance(BlimpTask):
DATASET_NAME = "wh_vs_that_no_gap_long_distance"
class BlimpWhVsThatWithGap(BlimpTask):
DATASET_NAME = "wh_vs_that_with_gap"
class BlimpWhVsThatWithGapLongDistance(BlimpTask):
DATASET_NAME = "wh_vs_that_with_gap_long_distance"
@@ -17,10 +17,6 @@ class CBTBase(HFTask):
     VERSION = 0

-    def fewshot_description(self):
-        # TODO: Figure out description.
-        return ""
-
     def detokenize(self, text):
         text = text.replace(" '", "'")
         text = text.replace(" \n", "\n")
...
@@ -16,8 +16,8 @@ class CoQA(Task):
         sh ("""mkdir -p data/coqa""")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", coqa_train_filepath, "b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
-        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", coqa_dev_filepath, "dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")
+        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json", local_file=coqa_train_filepath, expected_checksum="b0fdb2bc1bd38dd3ca2ce5fa2ac3e02c6288ac914f241ac409a655ffb6619fa6")
+        download_file("http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json", local_file=coqa_dev_filepath, expected_checksum="dfa367a9733ce53222918d0231d9b3bedc2b8ee831a2845f62dfc70701f2540a")

     def has_training_docs(self):
         return True
@@ -36,10 +36,7 @@ class CoQA(Task):
     def test_docs(self):
         pass

-    def fewshot_description(self):
-        return "Given a passage and a conversation so far, answer the next question in the conversation."
-
     def doc_to_text(self, doc):
         # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
         # and a question qi, the task is to predict the answer ai
...
@@ -27,7 +27,7 @@ class DROP(Task):
         url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
         checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
         zip_path = self.DATASET_PATH / "drop_dataset.zip"
-        download_file(url, str(zip_path), checksum)
+        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
         with ZipFile(zip_path, "r") as zip:
             zip.extractall(self.DATASET_PATH)
@@ -40,10 +40,6 @@ class DROP(Task):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO: figure out description
-        return ""
-
     def _load_docs(self, docs):
         for doc in docs:
             for qa in doc["qa_pairs"]:
...
@@ -21,10 +21,6 @@ class CoLA(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        # TODO
-        return ""
-
     def doc_to_text(self, doc):
         return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])
@@ -69,9 +65,6 @@ class SST(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        return "Indicate if the sentiment of each sentence is positive or negative."
-
     def doc_to_text(self, doc):
         return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
             general_detokenize(doc["sentence"]),
@@ -227,7 +220,7 @@ class QNLI(HFTask):
 class WNLI(HFTask):
-    VERSION = 0
+    VERSION = 1
     DATASET_PATH = "glue"
     DATASET_NAME = "wnli"
@@ -241,26 +234,25 @@ class WNLI(HFTask):
         return False

     def doc_to_text(self, doc):
-        return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
+        return "{}\nQuestion: {} True or False?\nAnswer:".format(
             doc["sentence1"],
             doc["sentence2"],
         )

     def doc_to_target(self, doc):
         # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]])
+        # False = not_entailment
+        return " {}".format({0: "False", 1: "True"}[doc["label"]])

     def construct_requests(self, doc, ctx):
         ll_true, _ = rf.loglikelihood(ctx, " True")
-        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
         ll_false, _ = rf.loglikelihood(ctx, " False")
-        return ll_true, ll_neither, ll_false
+        return ll_true, ll_false

     def process_results(self, doc, results):
+        ll_true, ll_false = results
+        pred = ll_true > ll_false
         gold = doc["label"]
-        pred = np.argmax(results)
         return {
             "acc": pred == gold
         }
@@ -342,9 +334,6 @@ class MRPC(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        return "Indicate if both sentences mean the same thing."
-
     def doc_to_text(self, doc):
         return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
             general_detokenize(doc["sentence1"]),
@@ -395,9 +384,6 @@ class QQP(HFTask):
     def has_test_docs(self):
         return False

-    def fewshot_description(self):
-        return "Indicate if both questions ask the same thing."
-
     def doc_to_text(self, doc):
         return "Question 1: {}\nQuestion 2: {}\nQuestion: Do both questions ask the same thing?\nAnswer:".format(
             doc["question1"],
@@ -448,10 +434,6 @@ class STSB(HFTask):
     def has_test_docs(self):
         return True

-    def fewshot_description(self):
-        return "Indicate if both sentences mean the same thing from a scale of 0-5, " \
-               "where 5 means identical and 0 means unrelated."
-
     def doc_to_text(self, doc):
         return "sentence 1: {}\nsentence 2: {}\nAnswer:".format(
             doc["sentence1"],
...