Unverified Commit b189066d authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #897 from EleutherAI/orz

Allow Generation arguments on greedy_until reqs
parents a3d08277 d6653610
......@@ -18,6 +18,8 @@ This mode supports a number of command-line arguments, the details of which can
* `--num_fewshot` : Sets the number of few-shot examples to place in context. Must be an integer.
* `--gen_kwargs` : Takes an argument string in the same format as `--model_args` and creates a dictionary of keyword arguments. These will be passed to the models for all called `generate_until` (free-form or greedy generation task) tasks, to set options such as the sampling temperature or `top_p` / `top_k`. For a list of the arguments supported by each model type, refer to the respective library's documentation (for example, the documentation for `transformers.AutoModelForCausalLM.generate()`). These kwargs will be applied to all `generate_until` tasks called; we do not currently support unique `gen_kwargs` or `batch_size` values per task in a single run of the library. To control these on a per-task level, set them in that task's YAML file.
* `--batch_size` : Sets the batch size used for evaluation. Can be a positive integer or `"auto"` to automatically select the largest batch size that will fit in memory, speeding up evaluation. One can pass `--batch_size auto:N` to re-select the maximum batch size `N` times during evaluation. This can help accelerate evaluation further, since `lm-eval` sorts documents in descending order of context length.
* `--max_batch_size` : Sets the maximum batch size to try to fit in memory, if `--batch_size auto` is passed.
......
......@@ -105,6 +105,14 @@ def parse_eval_args() -> argparse.Namespace:
default=None,
help="Additional path to include if there are external tasks to include.",
)
parser.add_argument(
"--gen_kwargs",
default="",
help=(
"String arguments for model generation on greedy_until tasks,"
" e.g. `temperature=0,top_k=0,top_p=0`"
),
)
parser.add_argument(
"--verbosity",
type=str,
......@@ -210,6 +218,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs,
)
if results is not None:
......@@ -236,7 +245,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
filename.open("w").write(samples_dumped)
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(evaluator.make_table(results))
......
......@@ -20,6 +20,7 @@ from lm_eval.utils import (
make_table,
create_iterator,
get_git_commit_hash,
simple_parse_args_string,
eval_logger,
)
......@@ -40,6 +41,7 @@ def simple_evaluate(
decontamination_ngrams_path=None,
write_out: bool = False,
log_samples: bool = True,
gen_kwargs: str = None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -70,6 +72,9 @@ def simple_evaluate(
If True, write out an example document and model input for checking task integrity
:param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param gen_kwargs: str
String arguments for model generation
Ignored for all tasks with loglikelihood output_type
:return
Dictionary of results
"""
......@@ -83,6 +88,14 @@ def simple_evaluate(
tasks != []
), "No tasks specified, or no tasks found. Please verify the task names."
if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
f"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
)
if gen_kwargs == "":
gen_kwargs = None
if isinstance(model, str):
if model_args is None:
model_args = ""
......@@ -117,6 +130,9 @@ def simple_evaluate(
continue
config = task_obj._config
if config["output_type"] == "generate_until" and gen_kwargs is not None:
config["generation_kwargs"].update(gen_kwargs)
if num_fewshot is not None:
if config["num_fewshot"] > 0:
default_num_fewshot = config["num_fewshot"]
......@@ -154,6 +170,7 @@ def simple_evaluate(
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs,
}
results["git_hash"] = get_git_commit_hash()
return results
......
......@@ -59,7 +59,12 @@ def handle_arg_string(arg):
return True
elif arg.lower() == "false":
return False
return arg
elif arg.isnumeric():
return int(arg)
try:
return float(arg)
except ValueError:
return arg
def simple_parse_args_string(args_string):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment