Commit d7a8ab24 authored by Jonathan Tow
Browse files

Add basic `description_dict`

test
parent ee53be21
import random
import lm_eval.tasks
import lm_eval.models
def test_description_dict():
    """Check that a per-task description from `description_dict` is injected
    into the few-shot context built by `Task.fewshot_context`.

    Builds a small description dict for two tasks, renders one example
    context per task, and asserts the description text appears in it.
    """
    seed = 42
    num_examples = 1
    task_names = ["hellaswag", "winogrande"]
    description_dict = {
        "hellaswag": "Label for the relevant action:\nSentences describing context, with an incomplete sentence trailing answer that plausibly completes the situation.",
        "winogrande": "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
    }
    task_dict = lm_eval.tasks.get_task_dict(task_names)
    for task_name, task in task_dict.items():
        # Fresh, seeded RNG per task so doc sampling is deterministic.
        rnd = random.Random()
        rnd.seed(seed)
        # BUG FIX: the original branches compared the *builtin* `set` type
        # to "val"/"test" (`set == "val"`), which is always False, leaving
        # `docs` unbound when a task has no training docs. Fall back
        # through the splits explicitly and fail loudly if none exist.
        if task.has_training_docs():
            docs = task.training_docs()
        elif task.has_validation_docs():
            docs = task.validation_docs()
        elif task.has_test_docs():
            docs = task.test_docs()
        else:
            raise RuntimeError(f"Task '{task_name}' has no docs to sample from")
        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )
        # num_examples > 0 caps how many docs we render; <= 0 means all.
        for _, doc in (
            zip(range(num_examples), docs)
            if num_examples > 0
            else enumerate(docs)
        ):
            ctx = task.fewshot_context(
                doc=doc,
                num_fewshot=1,
                provide_description=False,
                rnd=rnd,
                description=description,
            )
            print(ctx + "\n\n")
            assert description in ctx


test_description_dict()
\ No newline at end of file
{
"hellaswag": "Label for the relevant action:\nSentences describing context, with an incomplete sentence trailing answer that plausibly completes the situation."
}
import json
import argparse
import lm_eval.tasks
import lm_eval.models
from lm_eval.evaluator import evaluate
def test_cli_description_dict_path():
    """Exercise CLI-style loading of a description dict from a JSON file
    (via `--description_dict_path`) and pass it through `evaluate`.
    """

    def parse_args():
        # Mirrors the relevant flags of the main evaluation CLI.
        arg_parser = argparse.ArgumentParser()
        arg_parser.add_argument('--description_dict_path', default=None)
        arg_parser.add_argument('--num_fewshot', type=int, default=0)
        arg_parser.add_argument('--limit', type=int, default=None)
        return arg_parser.parse_args()

    args = parse_args()

    task_dict = lm_eval.tasks.get_task_dict(['hellaswag', 'copa'])
    lm = lm_eval.models.get_model('dummy')()

    # Default to no descriptions; load from JSON only when a path was given.
    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, 'r') as fp:
            description_dict = json.load(fp)

    results = evaluate(
        lm,
        task_dict,
        False,  # presumably provide_description — TODO confirm against evaluate()
        args.num_fewshot,
        args.limit,
        description_dict
    )


if __name__ == '__main__':
    test_cli_description_dict_path()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment