Commit ec5846d7 authored by USVSN Sai Prashanth
Browse files

Allow Generation arguments on greedy_until reqs

parent efa38caa
...@@ -21,6 +21,7 @@ from lm_eval.utils import ( ...@@ -21,6 +21,7 @@ from lm_eval.utils import (
make_table, make_table,
create_iterator, create_iterator,
get_git_commit_hash, get_git_commit_hash,
simple_parse_args_string
) )
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
...@@ -46,6 +47,7 @@ def simple_evaluate( ...@@ -46,6 +47,7 @@ def simple_evaluate(
decontamination_ngrams_path=None, decontamination_ngrams_path=None,
write_out: bool = False, write_out: bool = False,
log_samples: bool = True, log_samples: bool = True,
gen_kwargs: str = None
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -76,6 +78,9 @@ def simple_evaluate( ...@@ -76,6 +78,9 @@ def simple_evaluate(
If True, write out an example document and model input for checking task integrity If True, write out an example document and model input for checking task integrity
:param log_samples: bool :param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param gen_kwargs: str
String arguments for model generation
Ignored for all tasks with loglikelihood output_type
:return :return
Dictionary of results Dictionary of results
""" """
...@@ -123,6 +128,10 @@ def simple_evaluate( ...@@ -123,6 +128,10 @@ def simple_evaluate(
continue continue
config = task_obj._config config = task_obj._config
if config['output_type'] == 'greedy_until' and gen_kwargs != "":
gen_kwargs = simple_parse_args_string(gen_kwargs)
config['generation_kwargs'].update(gen_kwargs)
if num_fewshot is not None: if num_fewshot is not None:
if config["num_fewshot"] > 0: if config["num_fewshot"] > 0:
default_num_fewshot = config["num_fewshot"] default_num_fewshot = config["num_fewshot"]
...@@ -160,6 +169,7 @@ def simple_evaluate( ...@@ -160,6 +169,7 @@ def simple_evaluate(
"use_cache": use_cache, "use_cache": use_cache,
"limit": limit, "limit": limit,
"bootstrap_iters": bootstrap_iters, "bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs
} }
results["git_hash"] = get_git_commit_hash() results["git_hash"] = get_git_commit_hash()
return results return results
......
...@@ -50,6 +50,11 @@ def handle_arg_string(arg): ...@@ -50,6 +50,11 @@ def handle_arg_string(arg):
return True return True
elif arg.lower() == "false": elif arg.lower() == "false":
return False return False
elif arg.isnumeric():
return int(arg)
try:
return float(arg)
except ValueError:
return arg return arg
......
...@@ -97,6 +97,12 @@ def parse_args() -> argparse.Namespace: ...@@ -97,6 +97,12 @@ def parse_args() -> argparse.Namespace:
default=None, default=None,
help="Additional path to include if there are external tasks to include.", help="Additional path to include if there are external tasks to include.",
) )
parser.add_argument(
"--gen_kwargs",
default="",
help=("String arguments for model generation on greedy_until tasks,"
" e.g. `temperature=0,top_k=0,top_p=0`")
)
return parser.parse_args() return parser.parse_args()
...@@ -179,6 +185,7 @@ def main() -> None: ...@@ -179,6 +185,7 @@ def main() -> None:
check_integrity=args.check_integrity, check_integrity=args.check_integrity,
write_out=args.write_out, write_out=args.write_out,
log_samples=args.log_samples, log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs
) )
if results is not None: if results is not None:
...@@ -204,7 +211,7 @@ def main() -> None: ...@@ -204,7 +211,7 @@ def main() -> None:
f.write_all(samples[task_name]) f.write_all(samples[task_name])
print( print(
f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
) )
print(evaluator.make_table(results)) print(evaluator.make_table(results))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment