Unverified Commit 170ae096 authored by Leo Gao, committed by GitHub

Merge pull request #226 from jon-tow/evaluator-description-option

Replace the `fewshot_description` API with a `description_dict`-based interface
parents 8728710c 02a4def2
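The new interface reads per-task descriptions from a JSON file keyed by registered task names and threads each description through `fewshot_context`. A minimal sketch of how it can be driven end to end (the file name and description strings are illustrative, not taken from this diff):

import json

# hypothetical description file: keys must be registered task names
descriptions = {
    "lambada": "Complete the passage with the single most likely final word:",
    "webqs": "Answer the following question:",
}
with open("descriptions.json", "w") as f:
    json.dump(descriptions, f)

# then pass the file to the evaluation entry point via the new flag added in main.py:
#   python main.py --model gpt2 --tasks lambada,webqs --description_dict_path descriptions.json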
......@@ -85,9 +85,14 @@ class TruthfulQAMultipleChoice(Task):
def doc_to_target(self, doc):
return " "
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -217,9 +222,14 @@ class TruthfulQAGeneration(Task):
def doc_to_target(self, doc):
return " "
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd)
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
rnd=rnd,
description=description
)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......
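The updated `fewshot_context` signature shown above is shared by every task; `description` replaces the old `provide_description`/`fewshot_description` mechanism. A minimal sketch of the new call pattern (the registry name "truthfulqa_mc" and the description string are assumptions, not taken from this diff):

import random
from lm_eval import tasks

# assuming "truthfulqa_mc" is the registered name for TruthfulQAMultipleChoice
task = tasks.get_task("truthfulqa_mc")()
doc = next(iter(task.validation_docs()))
ctx = task.fewshot_context(
    doc=doc,
    num_fewshot=0,  # TruthfulQA is zero-shot only
    rnd=random.Random(42),
    description="Questions probing common misconceptions, answered truthfully.",
)
print(ctx)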
......@@ -45,9 +45,6 @@ class WordUnscrambleTask(Task):
file = self.BASE_PATH / self.FILENAME
return (json.loads(line) for line in open(file).read().splitlines())
def fewshot_description(self):
return "Please unscramble the letters into a word, and write that word:"
def doc_to_text(self, doc):
return doc["context"]
......
......@@ -17,10 +17,6 @@ class WebQs(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: figure out description
return ""
def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:'
......@@ -40,7 +36,6 @@ class WebQs(HFTask):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
ret = []
......@@ -62,4 +57,4 @@ class WebQs(HFTask):
def higher_is_better(self):
return {
"acc": True
}
\ No newline at end of file
}
......@@ -49,10 +49,6 @@ class WikiText(PerplexityTask):
download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def has_validation_docs(self):
return True
......
......@@ -29,10 +29,6 @@ class Winogrande(HFTask):
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
......
......@@ -53,10 +53,6 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k, rnd):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
......
import os
import re
import collections
import functools
import inspect
class ExitCodeError(Exception):
......@@ -138,4 +140,18 @@ class Reorderer:
assert all(cov)
return res
\ No newline at end of file
return res
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """
    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # a bound method receives `self` as an implicit first positional argument
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(f"WARNING: using {fn.__name__} with positional arguments is "
                  "deprecated and will be disallowed in a future version of "
                  "lm-evaluation-harness!")
        return fn(*args, **kwargs)
    return _wrapper
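A minimal usage sketch of the decorator, assuming it is importable as `lm_eval.utils.positional_deprecated` after this change; `evaluate_stub` is a hypothetical function used only for illustration:

from lm_eval.utils import positional_deprecated

@positional_deprecated
def evaluate_stub(model=None, tasks=None):
    # hypothetical helper; stands in for the decorated evaluate/simple_evaluate entry points
    return model, tasks

evaluate_stub(model="gpt2", tasks=["lambada"])  # keyword arguments: no warning
evaluate_stub("gpt2", ["lambada"])              # positional arguments: the deprecation warning is printed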
......@@ -19,6 +19,7 @@ def parse_args():
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
......@@ -34,15 +35,21 @@ def main():
else:
task_names = args.tasks.split(",")
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, 'r') as f:
description_dict = json.load(f)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
task_names=task_names,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict
)
dumped = json.dumps(results, indent=2)
......
......@@ -51,7 +51,14 @@ def main():
values = []
for taskname in task_list.split(","):
lm.tokencost = 0
evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10)
evaluator.evaluate(
lm=lm,
task_dict={taskname: tasks.get_task(taskname)()},
num_fewshot=0,
limit=None,
bootstrap_iters=10,
description_dict=None
)
print(taskname, lm.tokencost)
values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
......
import json
import numpy as np
import random
import logging
from lm_eval import models, tasks, evaluator, base
logging.getLogger("openai").setLevel(logging.WARNING)
fewshot_descriptions = [
"foo",
"bar"
]
task = "lambada"
num_fewshot = 0
model = "gpt2"
model_args = ""
limit = None
no_cache = False
class CustomDescTask:
    def __init__(self, task, desc):
        self.task = task
        self.desc = desc

        # nested closure: monkey-patch the wrapped task's fewshot_description
        def fewshot_description():
            return self.desc
        self.task.fewshot_description = fewshot_description

    def __getattr__(self, attr):
        return getattr(self.task, attr)
def main():
random.seed(42)
np.random.seed(42)
lm = models.get_model(model).create_from_arg_string(model_args)
if limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
if not no_cache:
lm = base.CachingLM(lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_') + '.db')
task_dict = tasks.get_task_dict([task])
for desc in fewshot_descriptions:
custom_task_dict = {k: CustomDescTask(v, desc) for k, v in task_dict.items()}
results = evaluator.evaluate(lm, custom_task_dict, True, num_fewshot, limit)
dumped = json.dumps(results, indent=2)
print('Description:', desc)
print(dumped)
# MAKE TABLE
from pytablewriter import MarkdownTableWriter
writer = MarkdownTableWriter()
writer.headers = ["Task", "Metric", "Value"]
values = []
for k, dic in results.items():
for m, v in dic.items():
values.append([k, m, '%.4f' % v])
k = ""
writer.value_matrix = values
print(writer.dumps())
if __name__ == "__main__":
main()
......@@ -9,7 +9,6 @@ for tname, Task in tasks.TASK_REGISTRY.items():#[('record', tasks.superglue.ReCo
print('#', tname)
docs = islice(task.validation_docs() if task.has_validation_docs() else task.test_docs(), ct)
print()
print('**Zero-Shot Prompt**:', "\n```\n" + task.fewshot_description() + "\n```\n")
for i in range(ct):
print()
doc = next(docs)
......
import argparse
import numpy as np
import json
import os
import random
from lm_eval import tasks
......@@ -17,6 +18,7 @@ def parse_args():
parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--num_examples', type=int, default=1)
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
......@@ -29,6 +31,12 @@ def main():
else:
task_names = args.tasks.split(",")
task_dict = tasks.get_task_dict(task_names)
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, 'r') as f:
description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True)
for task_name, task in task_dict.items():
rnd = random.Random()
......@@ -47,14 +55,16 @@ def main():
docs = join_iters(iters)
description = description_dict[task_name] if description_dict and task_name in description_dict else ""
with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
f.write(EXAMPLE_DIVIDER.format(i=i))
ctx = task.fewshot_context(
doc=doc,
provide_description=args.provide_description,
num_fewshot=args.num_fewshot,
rnd=rnd
rnd=rnd,
description=description
)
f.write(ctx + "\n")
......
import random
import lm_eval.tasks
import lm_eval.models
def test_description_dict():
seed = 42
num_examples = 1
task_names = ["hellaswag", "winogrande"]
description_dict = {
"hellaswag": "Label for the relevant action:\nSentences describing context, with an incomplete sentence trailing answer that plausibly completes the situation.",
"winogrande": "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in.",
}
task_dict = lm_eval.tasks.get_task_dict(task_names)
for task_name, task in task_dict.items():
rnd = random.Random()
rnd.seed(seed)
if task.has_training_docs():
docs = task.training_docs()
elif task.has_validation_docs():
docs = task.validation_docs()
elif task.has_test_docs():
docs = task.test_docs()
description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
for _, doc in (
zip(range(num_examples), docs) if num_examples > 0 else enumerate(docs)
):
ctx = task.fewshot_context(
doc=doc,
num_fewshot=1,
rnd=rnd,
description=description,
)
assert description in ctx
......@@ -48,8 +48,22 @@ def test_evaluator(taskname, task_class):
lm.loglikelihood_rolling = ll_perp_fn
limit = 10
e1 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
e2 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
e1 = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
e2 = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
# check that caching is working
assert e1 == e2
......@@ -99,5 +99,13 @@ def test_versions_stable(taskname, task_class):
lm.greedy_until = greedy_until
limit = None
result = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
result = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
assert_target(f"{taskname}-v{task_class.VERSION}-res", result)