Unverified Commit f8203de1 authored by Baber Abbasi, committed by GitHub

add bypass metric (#1156)

* add bypass metric

* fixed `bypass` metric.

* add task attributes if predict_only

* add `predict_only` checks

* add docs

* added `override_metric`, `override_config` to `Task`

* nits

* nit

* changed --predict_only to generations; nits

* nits

* nits

* change gen_kwargs warning

* add note about `--predict_only` in README.md

* added `predict_only`

* move table to bottom

* nit

* change null aggregation to bypass (conflict)

* bugfix; default `temp=0.0`

* typo
parent 084b7050
@@ -45,27 +45,7 @@ git clone https://github.com/EleutherAI/lm-evaluation-harness
cd lm-evaluation-harness
pip install -e .
```
-We also provide a number of optional dependencies for extended functionality. Extras can be installed via `pip install -e ".[NAME]"`
-| Name | Use |
-|---------------|---------------------------------------|
-| anthropic | For using Anthropic's models |
-| dev | For linting PRs and contributions |
-| gptq | For loading models with GPTQ |
-| ifeval | For running the IFEval task |
-| mamba | For loading Mamba SSM models |
-| math | For running math task answer checking |
-| multilingual | For multilingual tokenizers |
-| openai | For using OpenAI's models |
-| optimum | For running Intel OpenVINO models |
-| promptsource | For using PromptSource prompts |
-| sentencepiece | For using the sentencepiece tokenizer |
-| testing | For running library test suite |
-| vllm | For loading models with vLLM |
-| zeno | For visualizing results with Zeno |
-|---------------|---------------------------------------|
-| all | Loads all extras (not recommended) |
We also provide a number of optional dependencies for extended functionality. A detailed table is available at the end of this document.
## Basic Usage
@@ -204,6 +184,8 @@ A number of other libraries contain scripts for calling the eval harness through
To create your own custom integration you can follow instructions from [this tutorial](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage).
### Additional Features
> [!Note]
> For tasks unsuitable for direct evaluation — either due to risks associated with executing untrusted code or complexities in the evaluation process — the `--predict_only` flag is available to obtain decoded generations for post-hoc evaluation.
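As a rough sketch of such a post-hoc workflow using the external API, generations could be collected like this (the model, task, and result-key names below are illustrative assumptions; check the interface docs for the exact signature):
```python
import lm_eval
import lm_eval.tasks

# Load the library's stock tasks first (see the Tip on lm_eval.tasks.initialize_tasks()).
lm_eval.tasks.initialize_tasks()

# predict_only=True swaps the configured metrics for the no-op "bypass" metric,
# so only decoded generations are produced; log_samples=True keeps them.
results = lm_eval.simple_evaluate(
    model="hf",                                      # assumed: HF-backed model
    model_args="pretrained=EleutherAI/pythia-160m",  # assumed example checkpoint
    tasks=["gsm8k"],                                 # assumed example task
    predict_only=True,
    log_samples=True,
)
generations = results.get("samples", {})  # per-task decoded outputs for later scoring
```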
If you have a Metal compatible Mac, you can run the eval harness using the MPS back-end by replacing `--device cuda:0` with `--device mps` (requires PyTorch version 2.1 or higher).
@@ -254,7 +236,7 @@ Additionally, one can provide a directory with `--use_cache` to cache the result
For a full list of supported arguments, check out the [interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md) guide in our documentation!
> [!Tip]
-> Running lm-evaluation-harness as an external library and can't find (almost) any tasks available? run `lm_eval.tasks.initialize_tasks()` to load the library's stock tasks before calling `lm_eval.evaluate()` or `lm_eval.simple_evaluate()` !
> Running lm-evaluation-harness as an external library and can't find (almost) any tasks available? Run `lm_eval.tasks.initialize_tasks()` to load the library's stock tasks before calling `lm_eval.evaluate()` or `lm_eval.simple_evaluate()` !
## Visualizing Results
@@ -319,6 +301,28 @@ We try to prioritize agreement with the procedures used by other groups to decre
The best way to get support is to open an issue on this repo or join the [EleutherAI Discord server](https://discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for receiving support for our releases. If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!
## Optional Extras
Extra dependencies can be installed via `pip install -e ".[NAME]"`
| Name | Use |
|---------------|---------------------------------------|
| anthropic | For using Anthropic's models |
| dev | For linting PRs and contributions |
| gptq | For loading models with GPTQ |
| ifeval | For running the IFEval task |
| mamba | For loading Mamba SSM models |
| math | For running math task answer checking |
| multilingual | For multilingual tokenizers |
| openai | For using OpenAI's models |
| optimum | For running Intel OpenVINO models |
| promptsource | For using PromptSource prompts |
| sentencepiece | For using the sentencepiece tokenizer |
| testing | For running library test suite |
| vllm | For loading models with vLLM |
| zeno | For visualizing results with Zeno |
|---------------|---------------------------------------|
| all | Loads all extras (not recommended) |
## Cite as
```
@@ -44,6 +44,8 @@ This mode supports a number of command-line arguments, the details of which can
* `--include_path` : Accepts a path to a folder. If passed, then all YAML files containing `lm-eval` compatible task configurations will be added to the task registry as available tasks. Useful when one is writing config files for their own task in a folder other than `lm_eval/tasks/`
* `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results.
## External Library Usage
We also support using the library's external API for use within model training loops or other scripts.
@@ -143,6 +143,13 @@ def parse_eval_args() -> argparse.Namespace:
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
)
parser.add_argument(
"--predict_only",
"-x",
action="store_true",
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
return parser.parse_args()
@@ -156,6 +163,11 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Verbosity set to {args.verbosity}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
assert args.output_path, "Specify --output_path"
initialize_tasks(args.verbosity)
if args.limit:
@@ -223,8 +235,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
else:
path.mkdir(parents=True, exist_ok=True)
output_path_file = path.joinpath("results.json")
-elif args.log_samples and not args.output_path:
-assert args.output_path, "Specify --output_path"
eval_logger.info(f"Selected Tasks: {task_names}") eval_logger.info(f"Selected Tasks: {task_names}")
...@@ -243,6 +253,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -243,6 +253,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
write_out=args.write_out, write_out=args.write_out,
log_samples=args.log_samples, log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs, gen_kwargs=args.gen_kwargs,
predict_only=args.predict_only,
)
if results is not None:
@@ -15,6 +15,11 @@ eval_logger = logging.getLogger("lm-eval")
# Register Aggregations First
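# The 'bypass' aggregation is a placeholder for runs where metrics are skipped (--predict_only):
# it ignores the per-sample values and returns a constant sentinel.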
@register_aggregation("bypass")
def bypass_agg(arr):
return 999
@register_aggregation("mean") @register_aggregation("mean")
def mean(arr): def mean(arr):
return sum(arr) / len(arr) return sum(arr) / len(arr)
...@@ -207,6 +212,16 @@ def mean_stderr(arr): ...@@ -207,6 +212,16 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr)) return sample_stddev(arr) / math.sqrt(len(arr))
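# The 'bypass' metric is the per-sample counterpart: registered for loglikelihood,
# multiple_choice and generate_until outputs, it returns None for every item and
# aggregates via the 'bypass' aggregation above.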
@register_metric(
metric="bypass",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice", "generate_until"],
aggregation="bypass",
)
def bypass(items):
return None
@register_metric(
metric="mcc",
higher_is_better=True,
@@ -1213,12 +1213,46 @@ class ConfigurableTask(Task):
return result_dict
-def aggregation(self):
def aggregation(self) -> dict:
return self._aggregation_list
-def higher_is_better(self):
def higher_is_better(self) -> dict:
return self._higher_is_better
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
def override_metric(self, metric_name: str) -> None:
"""
Override the default metrics used for evaluation with custom metrics.
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
def override_config(
self, key: str = None, value: Any = None, update: bool = False
) -> None:
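# Replace the value stored under `key` in the task's config; with update=True,
# merge the dict `value` into the existing dict instead of replacing it.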
if update:
current_value = getattr(self._config, key)
assert isinstance(current_value, dict)
current_value.update(value)
setattr(self._config, key, current_value)
else:
setattr(self._config, key, value)
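As a rough usage sketch, mirroring how `simple_evaluate` drives these hooks later in this diff (`task_obj` and the example kwargs are assumptions for illustration, not part of the commit):
```python
# `task_obj` is assumed to be a ConfigurableTask obtained from the task registry.
if task_obj.get_config("output_type") == "generate_until":
    # Merge CLI-provided generation kwargs into the task's YAML-configured ones.
    task_obj.override_config(
        key="generation_kwargs",
        value={"temperature": 0.7, "do_sample": True},
        update=True,
    )
    # For --predict_only runs, swap every configured metric for the no-op "bypass" metric.
    task_obj.override_metric(metric_name="bypass")
```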
class MultipleChoiceTask(Task):
OUTPUT_TYPE: str = "loglikelihood"
@@ -38,6 +38,7 @@ def simple_evaluate(
write_out: bool = False,
log_samples: bool = True,
gen_kwargs: str = None,
predict_only: bool = False,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -71,6 +72,9 @@ def simple_evaluate(
:param gen_kwargs: str
String arguments for model generation
Ignored for all tasks with loglikelihood output_type
:param predict_only: bool
If True, only model outputs will be generated and returned; metrics will not be evaluated
:return
Dictionary of results
"""
@@ -89,7 +93,7 @@ def simple_evaluate(
if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks." "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!"
)
if gen_kwargs == "":
gen_kwargs = None
@@ -129,25 +133,30 @@ def simple_evaluate(
if task_obj is None:
continue
-config = task_obj._config
-if config["output_type"] == "generate_until" and gen_kwargs is not None:
-config["generation_kwargs"].update(gen_kwargs)
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.override_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
if predict_only:
log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
# we have to change the class properties post-hoc. This is pretty hacky.
task_obj.override_metric(metric_name="bypass")
if num_fewshot is not None:
-if config["num_fewshot"] == 0:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
-default_num_fewshot = config["num_fewshot"]
-if default_num_fewshot:
-# warn a user, if a specific num_fewshot > 0 was specified.
-# if unspecified in config, no warning message
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
-task_obj._config["num_fewshot"] = num_fewshot
task_obj.override_config(key="num_fewshot", value=num_fewshot)
if check_integrity:
run_task_tests(task_list=tasks)
@@ -223,6 +232,14 @@ def evaluate(
# decontaminate = decontamination_ngrams_path is not None
for task_name, task in task_dict.items():
if isinstance(task, tuple):
_, task = task
if not log_samples:
assert (
"bypass" not in getattr(task, "_metric_fn_list", {}).keys()
), f"log_samples must be True for 'bypass' only tasks: {task_name}"
# stores the final result for each task, for each metric/filter pair.
results = collections.defaultdict(dict)
# Tracks each task's version.
@@ -479,7 +496,8 @@ def evaluate(
if bool(results):
for group, task_list in reversed(task_hierarchy.items()):
if task_list == []:
total_size = results[group]["samples"] # TODO: No samples when bypass
total_size = results[group].get("samples", 999)
else:
total_size = 0
@@ -715,12 +715,14 @@ class HFLM(LM):
return self.model(inps).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# temperature = 0.0 if not set
# if do_sample is false and temp==0.0:
# remove temperature, as do_sample=False takes care of this
# and we don't want a warning from HF
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", None)
-if do_sample is False and "temperature" == 0.0:
-generation_kwargs.pop("temperature", 0.0)
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
generation_kwargs.pop("temperature")
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, context.shape[1], context.shape[0]
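For clarity, a minimal standalone sketch of the kwarg handling in this hunk; the helper name and the example calls are assumptions for illustration, not part of the change:
```python
from typing import Any, Dict


def normalize_generation_kwargs(generation_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror the hunk above: default temperature to 0.0, then drop it when decoding greedily."""
    kwargs = dict(generation_kwargs)
    # temperature = 0.0 if not set by the caller
    kwargs["temperature"] = kwargs.get("temperature", 0.0)
    do_sample = kwargs.get("do_sample", None)
    # with do_sample=False and temperature==0.0, temperature is redundant and would
    # only trigger a HuggingFace warning, so it is removed
    if do_sample is False and kwargs.get("temperature") == 0.0:
        kwargs.pop("temperature")
    return kwargs


print(normalize_generation_kwargs({"do_sample": False}))  # -> {'do_sample': False}
print(normalize_generation_kwargs({"temperature": 0.7}))  # -> {'temperature': 0.7}
```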