Commit 0e232f7a authored by Jonathan Tow
Browse files

Update new `task` arg and task dict getter

parent 57d0718a
...@@ -244,10 +244,13 @@ def get_task_name_from_object(task_object): ...@@ -244,10 +244,13 @@ def get_task_name_from_object(task_object):
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
return { task_name_dict = {
task_name: get_task(task_name)() task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str) for task_name in task_name_list if isinstance(task_name, str)
} + { }
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str) for task_object in task_name_list if not isinstance(task_object, str)
} }
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict}
...@@ -36,12 +36,11 @@ class PROST(HFTask, MultipleChoiceTask): ...@@ -36,12 +36,11 @@ class PROST(HFTask, MultipleChoiceTask):
def has_test_docs(self): def has_test_docs(self):
return True return True
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None): def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.' assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
return super().fewshot_context( return super().fewshot_context(
doc=doc, doc=doc,
num_fewshot=num_fewshot, num_fewshot=num_fewshot,
provide_description=provide_description,
rnd=rnd, rnd=rnd,
description=description description=description
) )
......
...@@ -43,7 +43,7 @@ def main(): ...@@ -43,7 +43,7 @@ def main():
results = evaluator.simple_evaluate( results = evaluator.simple_evaluate(
model=args.model, model=args.model,
model_args=args.model_args, model_args=args.model_args,
task_names=task_names, tasks=task_names,
num_fewshot=args.num_fewshot, num_fewshot=args.num_fewshot,
batch_size=args.batch_size, batch_size=args.batch_size,
device=args.device, device=args.device,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment