Commit 3b5e554f authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

move to __main__.py

parent 6f92c20d
......@@ -12,10 +12,9 @@ from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger, SPACING
from lm_eval.tasks import include_task_folder
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import Union
def parse_args() -> argparse.Namespace:
def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument(
......@@ -100,8 +99,13 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def main() -> None:
args = parse_args()
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if not args:
# we allow for args to be passed externally, else we parse them ourselves
args = parse_eval_args()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if args.limit:
eval_logger.warning(
......@@ -212,5 +216,5 @@ def main() -> None:
print(evaluator.make_table(results, "groups"))
if __name__ == "__main__":
main()
if __name__ == "__main__":
cli_evaluate()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment