Unverified Commit bfbd0325 authored by Amine Elhattami, committed by GitHub

Added seeds to `evaluator.simple_evaluate` signature (#1412)

* Added seeds to `evaluator.simple_evaluate` signature

* Added `--seed` CLI argument

* Updated the docs to add the `--seed` arg.
parent b69c67c1
@@ -46,6 +46,8 @@ This mode supports a number of command-line arguments, the details of which can
* `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results.
* `--seed`: Set the seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. Each value is either an integer or `None` to leave that seed unset. The default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`; numpy's seed is not set here since the second value is `None`. E.g. `--seed 42` sets all three seeds to 42.
## External Library Usage
We also support using the library's external API for use within model training loops or other scripts.
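For example, the new seed parameters can be passed straight through this API. The sketch below is illustrative only: the model and task names are placeholders, and it assumes `simple_evaluate` is importable as `lm_eval.simple_evaluate`; only the three seed keyword arguments are taken from this PR.

```python
import lm_eval

# Illustrative call; the model/task choices are placeholders, not part of this PR.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m",
    tasks=["lambada_openai"],
    random_seed=0,           # seeds python's `random`; pass None to leave it untouched
    numpy_random_seed=1234,  # seeds numpy; pass None to leave it untouched
    torch_random_seed=1234,  # seeds torch; pass None to leave it untouched
)
```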
@@ -4,6 +4,7 @@ import logging
import os
import re
import sys
from functools import partial
from pathlib import Path
from typing import Union
@@ -23,6 +24,30 @@ def _handle_non_serializable(o):
        return str(o)
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )

    return items
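As a quick sanity check of the helper's behavior (a standalone sketch; these assertions are not part of the PR):

```python
# A single value is broadcast to all three slots, and a case-insensitive
# "none" disables seeding for that slot.
assert _int_or_none_list_arg_type(3, "42") == [42, 42, 42]
assert _int_or_none_list_arg_type(3, "0,None,8") == [0, None, 8]

# Wrong counts (e.g. "0,1") or non-integer tokens (e.g. "foo") raise
# argparse.ArgumentTypeError.
```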
def parse_eval_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
@@ -149,6 +174,19 @@ def parse_eval_args() -> argparse.Namespace:
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g. `--seed 42` sets all three seeds to 42."
        ),
    )

    return parser.parse_args()
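Two details of the wiring above are worth calling out: `functools.partial` pins `max_len=3` so argparse can call the converter with only the raw string, and because the default is the string `"0,1234,1234"`, argparse runs it through the same converter when `--seed` is omitted. A minimal, standalone sketch (assuming `_int_or_none_list_arg_type` from above is in scope):

```python
import argparse
from functools import partial

parser = argparse.ArgumentParser()
parser.add_argument(
    "--seed",
    type=partial(_int_or_none_list_arg_type, 3),  # binds max_len=3
    default="0,1234,1234",
)

# argparse applies the type converter to string defaults, so args.seed is
# always a 3-element list, whether or not --seed appears on the command line.
assert parser.parse_args([]).seed == [0, 1234, 1234]
assert parser.parse_args(["--seed", "0,None,8"]).seed == [0, None, 8]
```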
@@ -255,6 +293,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
    )

    if results is not None:
@@ -40,6 +40,9 @@ def simple_evaluate(
    task_manager: TaskManager = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
):
    """Instantiate and evaluate a model on a list of tasks.
@@ -75,18 +78,31 @@
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If true only model outputs will be generated and returned. Metrics will not be evaluated
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.

    :return
        Dictionary of results
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        eval_logger.info(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        eval_logger.info(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        eval_logger.info(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if tasks is None:
        tasks = []

    assert (