Commit 41ecaba8 authored by Baber's avatar Baber
Browse files

add to task_manager

parent 86cf5dc9
......@@ -262,6 +262,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
parser.add_argument(
"--mcq_to_generative",
action="store_true",
help="Convert multiple choice tasks to generative tasks, for models that don't support logit outputs",
)
return parser
......@@ -306,7 +311,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
task_manager = TaskManager(
args.verbosity,
include_path=args.include_path,
mcq_to_generative=args.mcq_to_generative,
)
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
eval_logger.warning(
......@@ -410,6 +419,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
mcq_to_generative=args.mcq_to_generative,
**request_caching_args,
)
......
......@@ -75,6 +75,7 @@ def simple_evaluate(
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
mcq_to_generative: bool = False,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -231,7 +232,7 @@ def simple_evaluate(
)
if task_manager is None:
task_manager = TaskManager(verbosity)
task_manager = TaskManager(verbosity, mcq_to_generative=mcq_to_generative)
task_dict = get_task_dict(tasks, task_manager)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment