Unverified Commit ae72cebc authored by LSinev's avatar LSinev Committed by GitHub
Browse files

Provide ability for custom sampler for ConfigurableTask (#1616)

* Added fewshot sampling seeds to evaluator.simple_evaluate signature

Way to control seed of fewshot sampling
may help with #1591

* Added ability for custom sampler for ConfigurableTask

May be set in config like
```
fewshot_config:
  sampler: !function utils.MyFewshotSampler
```

* explicitly set fewshot random generator seed for HFLM generate_until_task test

* add backward compatibility for three args seed setup

* save seeds info to logs/reports
parent 30c060d2
......@@ -13,7 +13,9 @@ from lm_eval.tasks import TaskManager
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
def _int_or_none_list_arg_type(
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
def parse_value(item):
item = item.strip().lower()
if item == "none":
......@@ -29,10 +31,19 @@ def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
if num_items == 1:
# Makes downstream handling the same for single and multiple values
items = items * max_len
elif num_items != max_len:
elif num_items < min_len or num_items > max_len:
raise argparse.ArgumentTypeError(
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
)
elif num_items != max_len:
logging.warning(
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
"Missing values will be filled with defaults."
)
default_items = [parse_value(v) for v in defaults.split(split_char)]
items.extend(
default_items[num_items:]
) # extend items list with missing defaults
return items
......@@ -200,17 +211,20 @@ def setup_parser() -> argparse.ArgumentParser:
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
default_seed_string = "0,1234,1234,1234"
parser.add_argument(
"--seed",
type=partial(_int_or_none_list_arg_type, 3),
default="0,1234,1234", # for backward compatibility
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
default=default_seed_string, # for backward compatibility
help=(
"Set seed for python's random, numpy and torch.\n"
"Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
"or a single integer to set the same seed for all three.\n"
"The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
"E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
"E.g, `--seed 42` sets all three seeds to 42."
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
"respectively, or a single integer to set the same seed for all three.\n"
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
"(for backward compatibility).\n"
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
"Here numpy's seed is not set since the second value is `None`.\n"
"E.g, `--seed 42` sets all four seeds to 42."
),
)
parser.add_argument(
......@@ -350,6 +364,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
random_seed=args.seed[0],
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
**request_caching_args,
)
......
class ContextSampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.rnd = rnd
assert self.rnd, "must pass rnd to FewShotSampler!"
if not self.rnd:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
self.task = task
self.config = task._config
......
......@@ -229,6 +229,9 @@ class Task(abc.ABC):
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
self.fewshot_rnd: Optional[
random.Random
] = None # purposely induce errors in case of improper usage
def download(
self,
......@@ -520,7 +523,7 @@ class Task(abc.ABC):
self,
doc,
num_fewshot,
rnd=random.Random(1234),
rnd=None,
description=None,
):
"""Returns a fewshot context string that is made up of a prepended description
......@@ -539,9 +542,12 @@ class Task(abc.ABC):
The fewshot context.
"""
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd`"
)
if self.fewshot_rnd is not None:
rnd = self.fewshot_rnd
else:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd`"
)
description = description if description else ""
......@@ -632,6 +638,11 @@ class Task(abc.ABC):
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd
@property
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
if self.has_test_docs():
......@@ -808,11 +819,29 @@ class ConfigurableTask(Task):
self.prompt = None
if self.fewshot_docs() is not None:
self.sampler = samplers.get_sampler(
self.fewshot_rnd = (
random.Random()
) # setting with no seed, to be overridden at a later time
config_sampler: Union[str, Callable] = (
self.config.fewshot_config.get("sampler", "default")
if self.config.fewshot_config
else "default"
)(list(self.fewshot_docs()), self, rnd=random.Random(1234))
)
if isinstance(config_sampler, str):
self.sampler = samplers.get_sampler(config_sampler)(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
)
elif callable(config_sampler) and issubclass(
config_sampler, samplers.ContextSampler
):
self.sampler = config_sampler(
docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
)
else:
raise TypeError(
f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
f"not {type(config_sampler)}"
)
self.task_docs = self.eval_docs
......
......@@ -62,6 +62,7 @@ def simple_evaluate(
random_seed: int = 0,
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -109,6 +110,8 @@ def simple_evaluate(
Random seed for numpy. If set to None, the seed will not be set.
:param torch_random_seed: int
Random seed for torch. If set to None, the seed will not be set.
:param fewshot_random_seed: int
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
:return
Dictionary of results
......@@ -247,6 +250,10 @@ def simple_evaluate(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
eval_logger.info(
f"Setting fewshot random generator seed to {fewshot_random_seed}"
)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
......@@ -295,6 +302,10 @@ def simple_evaluate(
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs,
"random_seed": random_seed,
"numpy_seed": numpy_random_seed,
"torch_seed": torch_random_seed,
"fewshot_seed": fewshot_random_seed,
}
)
results["git_hash"] = get_git_commit_hash()
......
......@@ -23,6 +23,7 @@ class Test_HFLM:
MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
generate_until_task = task_list["gsm8k"] # type: ignore
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until_task.set_fewshot_seed(1234) # fewshot random generator seed
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until: list[Instance] = generate_until_task.instances
rolling_task = task_list["wikitext"] # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment