Unverified Commit 92f30afd authored by Vicki Boykis's avatar Vicki Boykis Committed by GitHub
Browse files

Proposed approach for testing CLI arg parsing (#1566)

* New tests for CLI args

* fix spacing

* change tests for parsing

* add tests, fix parser

* remove defaults for store_true
parent dc90fecc
...@@ -53,13 +53,30 @@ def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","): ...@@ -53,13 +53,30 @@ def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
return items return items
def parse_eval_args() -> argparse.Namespace: def check_argument_types(parser: argparse.ArgumentParser):
"""
Check to make sure all CLI args are typed, raises error if not
"""
for action in parser._actions:
if action.dest != "help" and not action.const:
if action.type is None:
raise ValueError(
f"Argument '{action.dest}' doesn't have a type specified."
)
else:
continue
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`") parser.add_argument(
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
)
parser.add_argument( parser.add_argument(
"--tasks", "--tasks",
"-t", "-t",
default=None, default=None,
type=str,
metavar="task1,task2", metavar="task1,task2",
help="To get full list of tasks, use the command lm-eval --tasks list", help="To get full list of tasks, use the command lm-eval --tasks list",
) )
...@@ -67,6 +84,7 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -67,6 +84,7 @@ def parse_eval_args() -> argparse.Namespace:
"--model_args", "--model_args",
"-a", "-a",
default="", default="",
type=str,
help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`", help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
) )
parser.add_argument( parser.add_argument(
...@@ -164,6 +182,7 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -164,6 +182,7 @@ def parse_eval_args() -> argparse.Namespace:
) )
parser.add_argument( parser.add_argument(
"--gen_kwargs", "--gen_kwargs",
type=dict,
default=None, default=None,
help=( help=(
"String arguments for model generation on greedy_until tasks," "String arguments for model generation on greedy_until tasks,"
...@@ -180,6 +199,7 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -180,6 +199,7 @@ def parse_eval_args() -> argparse.Namespace:
) )
parser.add_argument( parser.add_argument(
"--wandb_args", "--wandb_args",
type=str,
default="", default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval", help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
) )
...@@ -209,13 +229,19 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -209,13 +229,19 @@ def parse_eval_args() -> argparse.Namespace:
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub", help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
) )
return parser
def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
check_argument_types(parser)
return parser.parse_args() return parser.parse_args()
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if not args: if not args:
# we allow for args to be passed externally, else we parse them ourselves # we allow for args to be passed externally, else we parse them ourselves
args = parse_eval_args() parser = setup_parser()
args = parse_eval_args(parser)
if args.wandb_args: if args.wandb_args:
wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args)) wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))
......
import argparse
import pytest
import lm_eval.__main__
def test_cli_parse_error():
"""
Assert error raised if cli args argument doesn't have type
"""
with pytest.raises(ValueError):
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument(
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
)
parser.add_argument(
"--tasks",
"-t",
default=None,
metavar="task1,task2",
help="To get full list of tasks, use the command lm-eval --tasks list",
)
lm_eval.__main__.check_argument_types(parser)
def test_cli_parse_no_error():
"""
Assert typed arguments are parsed correctly
"""
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument(
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
)
parser.add_argument(
"--tasks",
"-t",
type=str,
default=None,
metavar="task1,task2",
help="To get full list of tasks, use the command lm-eval --tasks list",
)
lm_eval.__main__.check_argument_types(parser)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment