Commit 1d04c42d authored by Jonathan Tow
Browse files

Merge

parent ff314d62
...@@ -452,12 +452,17 @@ class Task(abc.ABC): ...@@ -452,12 +452,17 @@ class Task(abc.ABC):
def fewshot_description(self): def fewshot_description(self):
import warnings import warnings
warnings.warn( warnings.warn(
"`fewshot_description` will be removed in coming versions. Pass " \ "`fewshot_description` will be removed in futures versions. Pass " \
"any custom descriptions to the `evaluate` function instead.", "any custom descriptions to the `evaluate` function instead.",
DeprecationWarning) DeprecationWarning)
return "" return ""
def fewshot_context(self, doc, num_fewshot, rnd, description=None): def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To provide "
"custom descriptions on a per-task basis, supply the `description_dict` "
"arg with your task-to-description dictionary."
)
description = description + "\n\n" if description else "" description = description + "\n\n" if description else ""
if num_fewshot == 0: if num_fewshot == 0:
...@@ -531,9 +536,13 @@ class PerplexityTask(Task, abc.ABC): ...@@ -531,9 +536,13 @@ class PerplexityTask(Task, abc.ABC):
assert k == 0 assert k == 0
return [] return []
def fewshot_context(self, doc, num_fewshot, rnd, description=None): def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
assert num_fewshot == 0 assert num_fewshot == 0
assert description is None assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To provide "
"custom descriptions on a per-task basis, supply the `description_dict` "
"arg with your task-to-description dictionary."
)
return "" return ""
def higher_is_better(self): def higher_is_better(self):
......
...@@ -48,12 +48,13 @@ def simple_evaluate(model, model_args, task_names, ...@@ -48,12 +48,13 @@ def simple_evaluate(model, model_args, task_names,
) )
task_dict = lm_eval.tasks.get_task_dict(task_names) task_dict = lm_eval.tasks.get_task_dict(task_names)
description_dict = {} description_dict = {}
if description_path: if description_dict_path:
with open(description_path, 'r') as f: with open(description_dict_path, 'r') as f:
description_dict = json.load(f) description_dict = json.load(f)
results = evaluate(lm, task_dict, num_fewshot, limit, description_dict) results = evaluate(lm, task_dict, False, num_fewshot, limit, description_dict=description_dict)
# add info about the model and few shot config # add info about the model and few shot config
results["config"] = { results["config"] = {
...@@ -62,8 +63,6 @@ def simple_evaluate(model, model_args, task_names, ...@@ -62,8 +63,6 @@ def simple_evaluate(model, model_args, task_names,
"num_fewshot": num_fewshot, "num_fewshot": num_fewshot,
"batch_size": batch_size, "batch_size": batch_size,
"device": device, "device": device,
# TODO (jon-tow): Should we add the description info to `results["config"]`?
# "description_dict": description_dict,
"no_cache": no_cache, "no_cache": no_cache,
"limit": limit, "limit": limit,
"bootstrap_iters": bootstrap_iters "bootstrap_iters": bootstrap_iters
...@@ -140,6 +139,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i ...@@ -140,6 +139,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_i
ctx = task.fewshot_context( ctx = task.fewshot_context(
doc=doc, doc=doc,
num_fewshot=num_fewshot, num_fewshot=num_fewshot,
provide_description=provide_description,
rnd=rnd, rnd=rnd,
description=description description=description
) )
......
...@@ -12,13 +12,14 @@ def parse_args(): ...@@ -12,13 +12,14 @@ def parse_args():
parser.add_argument('--model', required=True) parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="") parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks") parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--description_path', default=None) parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0) parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None) parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=str, default=None) parser.add_argument('--device', type=str, default=None)
parser.add_argument('--output_path', default=None) parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None) parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true") parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args() return parser.parse_args()
......
...@@ -13,11 +13,12 @@ def parse_args(): ...@@ -13,11 +13,12 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--output_base_path', required=True) parser.add_argument('--output_base_path', required=True)
parser.add_argument('--tasks', default="all_tasks") parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--description_path', default=None) parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--sets', type=str, default="val") # example: val,test parser.add_argument('--sets', type=str, default="val") # example: val,test
parser.add_argument('--num_fewshot', type=int, default=1) parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--seed', type=int, default=42) parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--num_examples', type=int, default=1) parser.add_argument('--num_examples', type=int, default=1)
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args() return parser.parse_args()
...@@ -32,8 +33,8 @@ def main(): ...@@ -32,8 +33,8 @@ def main():
task_dict = tasks.get_task_dict(task_names) task_dict = tasks.get_task_dict(task_names)
description_dict = {} description_dict = {}
if args.description_path: if args.description_dict_path:
with open(args.description_path, 'r') as f: with open(args.description_dict_path, 'r') as f:
description_dict = json.load(f) description_dict = json.load(f)
os.makedirs(args.output_base_path, exist_ok=True) os.makedirs(args.output_base_path, exist_ok=True)
...@@ -62,6 +63,7 @@ def main(): ...@@ -62,6 +63,7 @@ def main():
ctx = task.fewshot_context( ctx = task.fewshot_context(
doc=doc, doc=doc,
num_fewshot=args.num_fewshot, num_fewshot=args.num_fewshot,
provide_description=args.provide_description,
rnd=rnd, rnd=rnd,
description=description description=description
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment