Commit 0e232f7a authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Update new `task` arg and task dict getter

parent 57d0718a
......@@ -244,10 +244,13 @@ def get_task_name_from_object(task_object):
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
return {
task_name_dict = {
task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str)
} + {
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict}
......@@ -36,12 +36,11 @@ class PROST(HFTask, MultipleChoiceTask):
def has_test_docs(self):
return True
def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
return super().fewshot_context(
doc=doc,
num_fewshot=num_fewshot,
provide_description=provide_description,
rnd=rnd,
description=description
)
......
......@@ -43,7 +43,7 @@ def main():
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
task_names=task_names,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
device=args.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment