"vscode:/vscode.git/clone" did not exist on "75139ca33853c88968a6781e0b15992184e78446"
Commit 41ecaba8 authored by Baber's avatar Baber
Browse files

add to task_manager

parent 86cf5dc9
...@@ -262,6 +262,11 @@ def setup_parser() -> argparse.ArgumentParser: ...@@ -262,6 +262,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true", action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it", help="Confirm that you understand the risks of running unsafe code for tasks that require it",
) )
parser.add_argument(
"--mcq_to_generative",
action="store_true",
help="Convert multiple choice tasks to generative tasks, for models that don't support logit outputs",
)
return parser return parser
...@@ -306,7 +311,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -306,7 +311,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.include_path is not None: if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}") eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path) task_manager = TaskManager(
args.verbosity,
include_path=args.include_path,
mcq_to_generative=args.mcq_to_generative,
)
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples: if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
eval_logger.warning( eval_logger.warning(
...@@ -410,6 +419,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -410,6 +419,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
torch_random_seed=args.seed[2], torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3], fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code, confirm_run_unsafe_code=args.confirm_run_unsafe_code,
mcq_to_generative=args.mcq_to_generative,
**request_caching_args, **request_caching_args,
) )
......
...@@ -75,6 +75,7 @@ def simple_evaluate( ...@@ -75,6 +75,7 @@ def simple_evaluate(
torch_random_seed: int = 1234, torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234, fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False, confirm_run_unsafe_code: bool = False,
mcq_to_generative: bool = False,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -231,7 +232,7 @@ def simple_evaluate( ...@@ -231,7 +232,7 @@ def simple_evaluate(
) )
if task_manager is None: if task_manager is None:
task_manager = TaskManager(verbosity) task_manager = TaskManager(verbosity, mcq_to_generative=mcq_to_generative)
task_dict = get_task_dict(tasks, task_manager) task_dict = get_task_dict(tasks, task_manager)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment