Commit 1d04c42d authored by Jonathan Tow
Browse files

Merge

parent ff314d62
......@@ -452,12 +452,17 @@ class Task(abc.ABC):
def fewshot_description(self):
import warnings
warnings.warn(
"`fewshot_description` will be removed in coming versions. Pass " \
"`fewshot_description` will be removed in futures versions. Pass " \
"any custom descriptions to the `evaluate` function instead.",
DeprecationWarning)
return ""
def fewshot_context(self, doc, num_fewshot, rnd, description=None):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To provide "
"custom descriptions on a per-task basis, supply the `description_dict` "
"arg with your task-to-description dictionary."
)
description = description + "\n\n" if description else ""
if num_fewshot == 0:
......@@ -531,9 +536,13 @@ class PerplexityTask(Task, abc.ABC):
assert k == 0
return []
def fewshot_context(self, doc, num_fewshot, rnd, description=None):
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert num_fewshot == 0
assert description is None
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To provide "
"custom descriptions on a per-task basis, supply the `description_dict` "
"arg with your task-to-description dictionary."
)
return ""
def higher_is_better(self):
......
......@@ -48,12 +48,13 @@ def simple_evaluate(model, model_args, task_names,
)
task_dict = lm_eval.tasks.get_task_dict(task_names)
description_dict = {}
if description_path:
with open(description_path, 'r') as f:
if description_dict_path:
with open(description_dict_path, 'r') as f:
description_dict = json.load(f)
results = evaluate(lm, task_dict, num_fewshot, limit, description_dict)
results = evaluate(lm, task_dict, False, num_fewshot, limit, description_dict=description_dict)
# add info about the model and few shot config
results["config"] = {
......@@ -62,8 +63,6 @@ def simple_evaluate(model, model_args, task_names,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
"device": device,
# TODO (jon-tow): Should we add the description info to `results["config"]`?
# "description_dict": description_dict,
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters
......@@ -140,6 +139,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
ctx = task.fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
provide_description=provide_description,
rnd=rnd,
description=description
)
......
......@@ -12,13 +12,14 @@ def parse_args():
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--description_path', default=None)
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=str, default=None)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
......
......@@ -13,11 +13,12 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--output_base_path', required=True)
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--description_path', default=None)
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--sets', type=str, default="val") # example: val,test
parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--num_examples', type=int, default=1)
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
......@@ -32,8 +33,8 @@ def main():
task_dict = tasks.get_task_dict(task_names)
description_dict = {}
if args.description_path:
with open(args.description_path, 'r') as f:
if args.description_dict_path:
with open(args.description_dict_path, 'r') as f:
description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True)
......@@ -62,6 +63,7 @@ def main():
ctx = task.fewshot_context(
doc=doc,
num_fewshot=args.num_fewshot,
provide_description=args.provide_description,
rnd=rnd,
description=description
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment