Commit 8ebe36b2 authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Add positional arg deprecation decorator

parent d34ae3cf
......@@ -457,6 +457,7 @@ class Task(abc.ABC):
DeprecationWarning)
return ""
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
......
......@@ -6,8 +6,10 @@ import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import numpy as np
from lm_eval.utils import positional_deprecated
@positional_deprecated
def simple_evaluate(model, model_args, task_names,
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
......@@ -51,7 +53,14 @@ def simple_evaluate(model, model_args, task_names,
task_dict = lm_eval.tasks.get_task_dict(task_names)
results = evaluate(lm, task_dict, False, num_fewshot, limit, description_dict=description_dict)
results = evaluate(
lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=num_fewshot,
limit=limit,
description_dict=description_dict
)
# add info about the model and few shot config
results["config"] = {
......@@ -69,6 +78,7 @@ def simple_evaluate(model, model_args, task_names,
return results
@positional_deprecated
def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000, description_dict=None):
"""Instantiate and evaluate a model on a list of tasks.
......
......@@ -38,7 +38,13 @@ class PROST(HFTask, MultipleChoiceTask):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
    """Build the (zero-shot) context for `doc`.

    PROST is designed as a zero-shot probe, so `num_fewshot` must be 0;
    the actual context construction is delegated to the superclass.
    Arguments are forwarded as keywords so the `positional_deprecated`
    check on the base implementation is not triggered.
    """
    assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
    # NOTE: the earlier positional-args call to super() was dead/duplicated
    # code; only the keyword-argument call is kept.
    return super().fewshot_context(
        doc=doc,
        num_fewshot=num_fewshot,
        provide_description=provide_description,
        rnd=rnd,
        description=description,
    )
def _convert_standard(self, doc):
out_doc = {
......
......@@ -87,7 +87,13 @@ class TruthfulQAMultipleChoice(Task):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
    """Build the (zero-shot) context for `doc`.

    TruthfulQA multiple-choice is zero-shot only, so `num_fewshot` must
    be 0; context construction is delegated to the superclass. Arguments
    are forwarded as keywords so the `positional_deprecated` check on the
    base implementation is not triggered.
    """
    assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
    # NOTE: the earlier positional-args call to super() was dead/duplicated
    # code; only the keyword-argument call is kept.
    return super().fewshot_context(
        doc=doc,
        num_fewshot=num_fewshot,
        provide_description=provide_description,
        rnd=rnd,
        description=description,
    )
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -219,7 +225,12 @@ class TruthfulQAGeneration(Task):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
    """Build the (zero-shot) context for `doc`.

    TruthfulQA generation is zero-shot only, so `num_fewshot` must be 0;
    context construction is delegated to the superclass. Arguments are
    forwarded as keywords so the `positional_deprecated` check on the
    base implementation is not triggered.
    """
    assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
    # NOTE: the earlier positional-args call to super() was dead/duplicated
    # code; only the keyword-argument call is kept.
    return super().fewshot_context(
        doc=doc,
        num_fewshot=num_fewshot,
        provide_description=provide_description,
        rnd=rnd,
        description=description,
    )
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......
import collections
import functools
import inspect
import os
import re
class ExitCodeError(Exception):
......@@ -138,4 +139,18 @@ class Reorderer:
assert all(cov)
return res
\ No newline at end of file
return res
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.

    Bound methods always receive `self`/`cls` as a positional argument even
    when the caller supplies every explicit argument by keyword, so one
    implicit positional is tolerated when `fn`'s first parameter is named
    `self` or `cls`. Without this allowance, every call to a decorated
    method (e.g. `Task.fewshot_context`) would warn spuriously.
    """
    # Inspect the signature once at decoration time, not on every call.
    try:
        params = list(inspect.signature(fn).parameters)
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) have no introspectable
        # signature; fall back to the strict zero-positional check.
        params = []
    allowed_positional = 1 if params and params[0] in ("self", "cls") else 0

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) > allowed_positional:
            print(f"WARNING: using {fn.__name__} with positional arguments is "
                  "deprecated and will be disallowed in a future version of "
                  "lm-evaluation-harness!")
        return fn(*args, **kwargs)
    return _wrapper
......@@ -51,7 +51,15 @@ def main():
values = []
for taskname in task_list.split(","):
lm.tokencost = 0
evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10)
evaluator.evaluate(
lm=lm,
task_dict={taskname: tasks.get_task(taskname)()},
provide_description=False,
num_fewshot=0,
limit=None,
bootstrap_iters=10,
description_dict=None
)
print(taskname, lm.tokencost)
values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
......
......@@ -48,8 +48,24 @@ def test_evaluator(taskname, task_class):
lm.loglikelihood_rolling = ll_perp_fn
limit = 10
e1 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10, description_dict=None)
e2 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10, description_dict=None)
e1 = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
e2 = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
# check that caching is working
assert e1 == e2
......@@ -99,5 +99,14 @@ def test_versions_stable(taskname, task_class):
lm.greedy_until = greedy_until
limit = None
result = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
result = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment