Unverified Commit ae72cebc authored by LSinev's avatar LSinev Committed by GitHub
Browse files

Provide ability for custom sampler for ConfigurableTask (#1616)

* Added fewshot sampling seeds to evaluator.simple_evaluate signature

Provides a way to control the seed of fewshot sampling;
may help with #1591

* Added the ability to use a custom sampler for ConfigurableTask

It may be set in the config like this:
```
fewshot_config:
  sampler: !function utils.MyFewshotSampler
```

* explicitly set fewshot random generator seed for HFLM generate_until_task test

* add backward compatibility for the three-argument seed setup

* save seeds info to logs/reports
parent 30c060d2
...@@ -13,7 +13,9 @@ from lm_eval.tasks import TaskManager ...@@ -13,7 +13,9 @@ from lm_eval.tasks import TaskManager
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","): def _int_or_none_list_arg_type(
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
def parse_value(item): def parse_value(item):
item = item.strip().lower() item = item.strip().lower()
if item == "none": if item == "none":
...@@ -29,10 +31,19 @@ def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","): ...@@ -29,10 +31,19 @@ def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
if num_items == 1: if num_items == 1:
# Makes downstream handling the same for single and multiple values # Makes downstream handling the same for single and multiple values
items = items * max_len items = items * max_len
elif num_items != max_len: elif num_items < min_len or num_items > max_len:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
f"Argument requires {max_len} integers or None, separated by '{split_char}'" f"Argument requires {max_len} integers or None, separated by '{split_char}'"
) )
elif num_items != max_len:
logging.warning(
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
"Missing values will be filled with defaults."
)
default_items = [parse_value(v) for v in defaults.split(split_char)]
items.extend(
default_items[num_items:]
) # extend items list with missing defaults
return items return items
...@@ -200,17 +211,20 @@ def setup_parser() -> argparse.ArgumentParser: ...@@ -200,17 +211,20 @@ def setup_parser() -> argparse.ArgumentParser:
default=False, default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.", help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
) )
default_seed_string = "0,1234,1234,1234"
parser.add_argument( parser.add_argument(
"--seed", "--seed",
type=partial(_int_or_none_list_arg_type, 3), type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
default="0,1234,1234", # for backward compatibility default=default_seed_string, # for backward compatibility
help=( help=(
"Set seed for python's random, numpy and torch.\n" "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
"Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, " "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
"or a single integer to set the same seed for all three.\n" "respectively, or a single integer to set the same seed for all three.\n"
"The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n" f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
"E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n" "(for backward compatibility).\n"
"E.g, `--seed 42` sets all three seeds to 42." "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
"Here numpy's seed is not set since the second value is `None`.\n"
"E.g, `--seed 42` sets all four seeds to 42."
), ),
) )
parser.add_argument( parser.add_argument(
...@@ -350,6 +364,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -350,6 +364,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
random_seed=args.seed[0], random_seed=args.seed[0],
numpy_random_seed=args.seed[1], numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2], torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
**request_caching_args, **request_caching_args,
) )
......
class ContextSampler: class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None: def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.rnd = rnd self.rnd = rnd
assert self.rnd, "must pass rnd to FewShotSampler!" if not self.rnd:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
self.task = task self.task = task
self.config = task._config self.config = task._config
......
...@@ -229,6 +229,9 @@ class Task(abc.ABC): ...@@ -229,6 +229,9 @@ class Task(abc.ABC):
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig() self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._filters = [build_filter_ensemble("none", [["take_first", None]])] self._filters = [build_filter_ensemble("none", [["take_first", None]])]
self.fewshot_rnd: Optional[
random.Random
] = None # purposely induce errors in case of improper usage
def download( def download(
self, self,
...@@ -520,7 +523,7 @@ class Task(abc.ABC): ...@@ -520,7 +523,7 @@ class Task(abc.ABC):
self, self,
doc, doc,
num_fewshot, num_fewshot,
rnd=random.Random(1234), rnd=None,
description=None, description=None,
): ):
"""Returns a fewshot context string that is made up of a prepended description """Returns a fewshot context string that is made up of a prepended description
...@@ -539,9 +542,12 @@ class Task(abc.ABC): ...@@ -539,9 +542,12 @@ class Task(abc.ABC):
The fewshot context. The fewshot context.
""" """
if rnd is None: if rnd is None:
raise ValueError( if self.fewshot_rnd is not None:
"A `random.Random` generator argument must be provided to `rnd`" rnd = self.fewshot_rnd
) else:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd`"
)
description = description if description else "" description = description if description else ""
...@@ -632,6 +638,11 @@ class Task(abc.ABC): ...@@ -632,6 +638,11 @@ class Task(abc.ABC):
setattr(self._config, "metric_list", [{"metric": metric_name}]) setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None) setattr(self._config, "process_results", None)
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd
@property @property
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]: def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
if self.has_test_docs(): if self.has_test_docs():
...@@ -808,11 +819,29 @@ class ConfigurableTask(Task): ...@@ -808,11 +819,29 @@ class ConfigurableTask(Task):
self.prompt = None self.prompt = None
if self.fewshot_docs() is not None: if self.fewshot_docs() is not None:
self.sampler = samplers.get_sampler( self.fewshot_rnd = (
random.Random()
) # setting with no seed, to be overridden at a later time
config_sampler: Union[str, Callable] = (
self.config.fewshot_config.get("sampler", "default") self.config.fewshot_config.get("sampler", "default")
if self.config.fewshot_config if self.config.fewshot_config
else "default" else "default"
)(list(self.fewshot_docs()), self, rnd=random.Random(1234)) )
if isinstance(config_sampler, str):
self.sampler = samplers.get_sampler(config_sampler)(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
)
elif callable(config_sampler) and issubclass(
config_sampler, samplers.ContextSampler
):
self.sampler = config_sampler(
docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
)
else:
raise TypeError(
f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
f"not {type(config_sampler)}"
)
self.task_docs = self.eval_docs self.task_docs = self.eval_docs
......
...@@ -62,6 +62,7 @@ def simple_evaluate( ...@@ -62,6 +62,7 @@ def simple_evaluate(
random_seed: int = 0, random_seed: int = 0,
numpy_random_seed: int = 1234, numpy_random_seed: int = 1234,
torch_random_seed: int = 1234, torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -109,6 +110,8 @@ def simple_evaluate( ...@@ -109,6 +110,8 @@ def simple_evaluate(
Random seed for numpy. If set to None, the seed will not be set. Random seed for numpy. If set to None, the seed will not be set.
:param torch_random_seed: int :param torch_random_seed: int
Random seed for torch. If set to None, the seed will not be set. Random seed for torch. If set to None, the seed will not be set.
:param fewshot_random_seed: int
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
:return :return
Dictionary of results Dictionary of results
...@@ -247,6 +250,10 @@ def simple_evaluate( ...@@ -247,6 +250,10 @@ def simple_evaluate(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
) )
task_obj.set_config(key="num_fewshot", value=num_fewshot) task_obj.set_config(key="num_fewshot", value=num_fewshot)
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
eval_logger.info(
f"Setting fewshot random generator seed to {fewshot_random_seed}"
)
else: else:
# if num_fewshot not provided, and the task does not define a default one, default to 0 # if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None: if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
...@@ -295,6 +302,10 @@ def simple_evaluate( ...@@ -295,6 +302,10 @@ def simple_evaluate(
"limit": limit, "limit": limit,
"bootstrap_iters": bootstrap_iters, "bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs, "gen_kwargs": gen_kwargs,
"random_seed": random_seed,
"numpy_seed": numpy_random_seed,
"torch_seed": torch_random_seed,
"fewshot_seed": fewshot_random_seed,
} }
) )
results["git_hash"] = get_git_commit_hash() results["git_hash"] = get_git_commit_hash()
......
...@@ -23,6 +23,7 @@ class Test_HFLM: ...@@ -23,6 +23,7 @@ class Test_HFLM:
MULTIPLE_CH: list[Instance] = multiple_choice_task.instances MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
generate_until_task = task_list["gsm8k"] # type: ignore generate_until_task = task_list["gsm8k"] # type: ignore
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10 generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until_task.set_fewshot_seed(1234) # fewshot random generator seed
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1) generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until: list[Instance] = generate_until_task.instances generate_until: list[Instance] = generate_until_task.instances
rolling_task = task_list["wikitext"] # type: ignore rolling_task = task_list["wikitext"] # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment