Unverified Commit bfbd0325 authored by Amine Elhattami's avatar Amine Elhattami Committed by GitHub
Browse files

Added seeds to `evaluator.simple_evaluate` signature (#1412)

* Added seeds to `evaluator.simple_evaluate` signature

* Added `--seed` CLI argument

* Updated the documentation to add the `--seed` arg.
parent b69c67c1
...@@ -46,6 +46,8 @@ This mode supports a number of command-line arguments, the details of which can ...@@ -46,6 +46,8 @@ This mode supports a number of command-line arguments, the details of which can
* `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results. * `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results.
* `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g. `--seed 42` sets all three seeds to 42.
## External Library Usage ## External Library Usage
We also support using the library's external API for use within model training loops or other scripts. We also support using the library's external API for use within model training loops or other scripts.
......
...@@ -4,6 +4,7 @@ import logging ...@@ -4,6 +4,7 @@ import logging
import os import os
import re import re
import sys import sys
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
...@@ -23,6 +24,30 @@ def _handle_non_serializable(o): ...@@ -23,6 +24,30 @@ def _handle_non_serializable(o):
return str(o) return str(o)
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    """Parse a delimited string into a fixed-length list of ints and/or None.

    Written to be used as an argparse ``type=`` callable, with ``max_len``
    bound up front (e.g. via ``functools.partial``).

    :param max_len: int
        Number of entries the returned list must contain.
    :param value: str
        Raw argument text, e.g. ``"0,None,8"`` or ``"42"``.
    :param split_char: str
        Separator placed between entries in ``value``.
    :return
        A list of ``max_len`` entries, each an int or None. When ``value``
        holds a single entry, that entry is repeated ``max_len`` times.
    :raises argparse.ArgumentTypeError
        If an entry is neither an integer nor "none" (case-insensitive,
        surrounding whitespace ignored), or if the number of entries is
        neither 1 nor ``max_len``.
    """

    def parse_value(item):
        # Normalise so that " None", "NONE" and "none" are all accepted.
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError as err:
            # Chain explicitly so the original int() failure stays attached
            # as the cause instead of an implicit "during handling" context.
            raise argparse.ArgumentTypeError(
                f"{item} is not an integer or None"
            ) from err

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )

    return items
def parse_eval_args() -> argparse.Namespace: def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`") parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
...@@ -149,6 +174,19 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -149,6 +174,19 @@ def parse_eval_args() -> argparse.Namespace:
default=False, default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.", help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
) )
parser.add_argument(
"--seed",
type=partial(_int_or_none_list_arg_type, 3),
default="0,1234,1234", # for backward compatibility
help=(
"Set seed for python's random, numpy and torch.\n"
"Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
"or a single integer to set the same seed for all three.\n"
"The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
"E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
"E.g, `--seed 42` sets all three seeds to 42."
),
)
return parser.parse_args() return parser.parse_args()
...@@ -255,6 +293,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -255,6 +293,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
gen_kwargs=args.gen_kwargs, gen_kwargs=args.gen_kwargs,
task_manager=task_manager, task_manager=task_manager,
predict_only=args.predict_only, predict_only=args.predict_only,
random_seed=args.seed[0],
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
) )
if results is not None: if results is not None:
......
...@@ -40,6 +40,9 @@ def simple_evaluate( ...@@ -40,6 +40,9 @@ def simple_evaluate(
task_manager: TaskManager = None, task_manager: TaskManager = None,
verbosity: str = "INFO", verbosity: str = "INFO",
predict_only: bool = False, predict_only: bool = False,
random_seed: int = 0,
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -75,18 +78,31 @@ def simple_evaluate( ...@@ -75,18 +78,31 @@ def simple_evaluate(
Ignored for all tasks with loglikelihood output_type Ignored for all tasks with loglikelihood output_type
:param predict_only: bool :param predict_only: bool
If true only model outputs will be generated and returned. Metrics will not be evaluated If true only model outputs will be generated and returned. Metrics will not be evaluated
:param random_seed: int
Random seed for python's random module. If set to None, the seed will not be set.
:param numpy_random_seed: int
Random seed for numpy. If set to None, the seed will not be set.
:param torch_random_seed: int
Random seed for torch. If set to None, the seed will not be set.
:return :return
Dictionary of results Dictionary of results
""" """
random.seed(0)
np.random.seed(1234)
torch.manual_seed(
1234
) # TODO: this may affect training runs that are run with evaluation mid-run.
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
if random_seed is not None:
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
eval_logger.info(f"Setting random seed to {random_seed}")
random.seed(random_seed)
if numpy_random_seed is not None:
eval_logger.info(f"Setting numpy seed to {numpy_random_seed}")
np.random.seed(numpy_random_seed)
if torch_random_seed is not None:
eval_logger.info(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed)
if tasks is None: if tasks is None:
tasks = [] tasks = []
assert ( assert (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment