Commit 8ebe36b2 authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Add positional arg deprecation decorator

parent d34ae3cf
...@@ -457,6 +457,7 @@ class Task(abc.ABC): ...@@ -457,6 +457,7 @@ class Task(abc.ABC):
DeprecationWarning) DeprecationWarning)
return "" return ""
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None): def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert not provide_description, ( assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend " "The `provide_description` arg will be removed in future versions. To prepend "
......
...@@ -6,8 +6,10 @@ import lm_eval.models ...@@ -6,8 +6,10 @@ import lm_eval.models
import lm_eval.tasks import lm_eval.tasks
import lm_eval.base import lm_eval.base
import numpy as np import numpy as np
from lm_eval.utils import positional_deprecated
@positional_deprecated
def simple_evaluate(model, model_args, task_names, def simple_evaluate(model, model_args, task_names,
num_fewshot=0, batch_size=None, device=None, num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000, no_cache=False, limit=None, bootstrap_iters=100000,
...@@ -51,7 +53,14 @@ def simple_evaluate(model, model_args, task_names, ...@@ -51,7 +53,14 @@ def simple_evaluate(model, model_args, task_names,
task_dict = lm_eval.tasks.get_task_dict(task_names) task_dict = lm_eval.tasks.get_task_dict(task_names)
results = evaluate(lm, task_dict, False, num_fewshot, limit, description_dict=description_dict) results = evaluate(
lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=num_fewshot,
limit=limit,
description_dict=description_dict
)
# add info about the model and few shot config # add info about the model and few shot config
results["config"] = { results["config"] = {
...@@ -69,6 +78,7 @@ def simple_evaluate(model, model_args, task_names, ...@@ -69,6 +78,7 @@ def simple_evaluate(model, model_args, task_names,
return results return results
@positional_deprecated
def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000, description_dict=None): def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000, description_dict=None):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
......
...@@ -38,7 +38,13 @@ class PROST(HFTask, MultipleChoiceTask): ...@@ -38,7 +38,13 @@ class PROST(HFTask, MultipleChoiceTask):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None): def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.' assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
return super().fewshot_context(doc, num_fewshot, provide_description, rnd, description) return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
provide_description=provide_description,
rnd=rnd,
description=description
)
def _convert_standard(self, doc): def _convert_standard(self, doc):
out_doc = { out_doc = {
......
...@@ -87,7 +87,13 @@ class TruthfulQAMultipleChoice(Task): ...@@ -87,7 +87,13 @@ class TruthfulQAMultipleChoice(Task):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None): def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting." assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd, description) return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
provide_description=provide_description,
rnd=rnd,
description=description
)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -219,7 +225,12 @@ class TruthfulQAGeneration(Task): ...@@ -219,7 +225,12 @@ class TruthfulQAGeneration(Task):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None): def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting." assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
return super().fewshot_context(doc, num_fewshot, provide_description, rnd, description) return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
provide_description=provide_description,
rnd=rnd,
description=description)
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
......
import os import os
import re import re
import collections import collections
import functools
class ExitCodeError(Exception): class ExitCodeError(Exception):
...@@ -139,3 +140,17 @@ class Reorderer: ...@@ -139,3 +140,17 @@ class Reorderer:
assert all(cov) assert all(cov)
return res return res
def positional_deprecated(fn):
"""
A decorator to nudge users into passing only keyword args (`kwargs`) to the
wrapped function, `fn`.
"""
@functools.wraps(fn)
def _wrapper(*args, **kwargs):
if len(args) != 0:
print(f"WARNING: using {fn.__name__} with positional arguments is "
"deprecated and will be disallowed in a future version of "
"lm-evaluation-harness!")
return fn(*args, **kwargs)
return _wrapper
...@@ -51,7 +51,15 @@ def main(): ...@@ -51,7 +51,15 @@ def main():
values = [] values = []
for taskname in task_list.split(","): for taskname in task_list.split(","):
lm.tokencost = 0 lm.tokencost = 0
evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10) evaluator.evaluate(
lm=lm,
task_dict={taskname: tasks.get_task(taskname)()},
provide_description=False,
num_fewshot=0,
limit=None,
bootstrap_iters=10,
description_dict=None
)
print(taskname, lm.tokencost) print(taskname, lm.tokencost)
values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06]) values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012, lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
......
...@@ -48,8 +48,24 @@ def test_evaluator(taskname, task_class): ...@@ -48,8 +48,24 @@ def test_evaluator(taskname, task_class):
lm.loglikelihood_rolling = ll_perp_fn lm.loglikelihood_rolling = ll_perp_fn
limit = 10 limit = 10
e1 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10, description_dict=None) e1 = evaluator.evaluate(
e2 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10, description_dict=None) lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
e2 = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
# check that caching is working # check that caching is working
assert e1 == e2 assert e1 == e2
...@@ -99,5 +99,14 @@ def test_versions_stable(taskname, task_class): ...@@ -99,5 +99,14 @@ def test_versions_stable(taskname, task_class):
lm.greedy_until = greedy_until lm.greedy_until = greedy_until
limit = None limit = None
result = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10) result = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
provide_description=False,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
)
assert_target(f"{taskname}-v{task_class.VERSION}-res", result) assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment