Commit abd17276 authored by Baber's avatar Baber
Browse files

Merge branch 'smolrefact' into tasklist

# Conflicts:
#	lm_eval/__main__.py
#	lm_eval/api/group.py
#	lm_eval/api/task.py
#	lm_eval/evaluator_utils.py
#	lm_eval/tasks/__init__.py
#	lm_eval/utils.py
#	pyproject.toml
parents 00afd536 70314843
......@@ -32,10 +32,8 @@ repos:
rev: v0.12.5
hooks:
# Run the linter.
- id: ruff
args:
- --fix
# Run the formatter.
- id: ruff-check
args: [--fix]
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
......
......@@ -8,71 +8,160 @@ A majority of users run the library by cloning it from Github, installing the pa
Equivalently, running the library can be done via the `lm-eval` entrypoint at the command line.
This mode supports a number of command-line arguments, the details of which can also be seen via running with `-h` or `--help`:
### Subcommand Structure
- `--model` : Selects which model type or provider is evaluated. Must be a string corresponding to the name of the model type/provider being used. See [the main README](https://github.com/EleutherAI/lm-evaluation-harness/tree/main#model-apis-and-inference-servers) for a full list of enabled model names and supported libraries or APIs.
The CLI now uses a subcommand structure for better organization:
- `--model_args` : Controls parameters passed to the model constructor. Accepts a string containing comma-separated keyword arguments to the model class of the format `"arg1=val1,arg2=val2,..."`, such as, for example `--model_args pretrained=EleutherAI/pythia-160m,dtype=float32`. For a full list of what keyword arguments, see the initialization of the `lm_eval.api.model.LM` subclass, e.g. [`HFLM`](https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/models/huggingface.py#L66)
- `lm-eval run` - Execute evaluations (default behavior)
- `lm-eval ls` - List available tasks, models, etc.
- `lm-eval validate` - Validate task configurations
- `--tasks` : Determines which tasks or task groups are evaluated. Accepts a comma-separated list of task names or task group names. Must be solely comprised of valid tasks/groups. A list of supported tasks can be viewed with `--tasks list`.
For backward compatibility, if no subcommand is specified, `run` is automatically inserted. So `lm-eval --model hf --tasks hellaswag` is equivalent to `lm-eval run --model hf --tasks hellaswag`.
- `--num_fewshot` : Sets the number of few-shot examples to place in context. Must be an integer.
### Run Command Arguments
- `--gen_kwargs` : takes an arg string in same format as `--model_args` and creates a dictionary of keyword arguments. These will be passed to the models for all called `generate_until` (free-form or greedy generation task) tasks, to set options such as the sampling temperature or `top_p` / `top_k`. For a list of what args are supported for each model type, reference the respective library's documentation (for example, the documentation for `transformers.AutoModelForCausalLM.generate()`.) These kwargs will be applied to all `generate_until` tasks called--we do not currently support unique gen_kwargs or batch_size values per task in a single run of the library. To control these on a per-task level, set them in that task's YAML file.
The `run` command supports a number of command-line arguments. Details can also be seen via running with `-h` or `--help`:
- `--batch_size` : Sets the batch size used for evaluation. Can be a positive integer or `"auto"` to automatically select the largest batch size that will fit in memory, speeding up evaluation. One can pass `--batch_size auto:N` to re-select the maximum batch size `N` times during evaluation. This can help accelerate evaluation further, since `lm-eval` sorts documents in descending order of context length.
#### Configuration
- `--max_batch_size` : Sets the maximum batch size to try to fit in memory, if `--batch_size auto` is passed.
- `--config` **[path: str]** : Set initial arguments from a YAML configuration file. Takes a path to a YAML file that contains argument values. This allows you to specify complex configurations in a file rather than on the command line. Further CLI arguments can override values from the configuration file.
- `--device` : Sets which device to place the model onto. Must be a string, for example, `"cuda", "cuda:0", "cpu", "mps"`. Defaults to "cuda", and can be ignored if running multi-GPU or running a non-local model type.
For the complete list of available configuration fields and their types, see [`EvaluatorConfig` in the source code](../lm_eval/config/evaluate_config.py).
- `--output_path` : A string of the form `dir/file.jsonl` or `dir/`. Provides a path where high-level results will be saved, either into the file named or into the directory named. If `--log_samples` is passed as well, then per-document outputs and metrics will be saved into the directory as well.
#### Model and Tasks
- `--log_samples` : If this flag is passed, then the model's outputs, and the text fed into the model, will be saved at per-document granularity. Must be used with `--output_path`.
- `--model` **[str, default: "hf"]** : Selects which model type or provider is evaluated. Must be a string corresponding to the name of the model type/provider being used. See [the main README](https://github.com/EleutherAI/lm-evaluation-harness/tree/main#model-apis-and-inference-servers) for a full list of enabled model names and supported libraries or APIs.
- `--limit` : Accepts an integer, or a float between 0.0 and 1.0 . If passed, will limit the number of documents to evaluate to the first X documents (if an integer) per task or first X% of documents per task. Useful for debugging, especially on costly API models.
- `--model_args` **[comma-sep str | json str → dict]** : Controls parameters passed to the model constructor. Can be provided as:
- Comma-separated string: `pretrained=EleutherAI/pythia-160m,dtype=float32`
- JSON string: `'{"pretrained": "EleutherAI/pythia-160m", "dtype": "float32"}'`
- `--use_cache` : Should be a path where a sqlite db file can be written to. Takes a string of format `/path/to/sqlite_cache_` in order to create a cache db at `/path/to/sqlite_cache_rank{i}.db` for each process (0-NUM_GPUS). This allows results of prior runs to be cached, so that there is no need to re-run results in order to re-score or re-run a given (model, task) pair again.
For a full list of supported arguments, see the initialization of the `lm_eval.api.model.LM` subclass, e.g. [`HFLM`](https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/models/huggingface.py#L66)
- `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`.
- `--tasks` **[comma-sep str → list[str]]** : Determines which tasks or task groups are evaluated. Accepts a comma-separated list of task names or task group names. Must be solely comprised of valid tasks/groups. A list of supported tasks can be viewed with `lm-eval list tasks`.
- `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
#### Evaluation Settings
- `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task.
- `--num_fewshot` **[int]** : Sets the number of few-shot examples to place in context. Must be an integer.
- `--show_config` : If used, prints the full `lm_eval.api.task.TaskConfig` contents (non-default settings the task YAML file) for each task which was run, at the completion of an evaluation. Useful for when one is modifying a task's configuration YAML locally to transmit the exact configurations used for debugging or for reproducibility purposes.
- `--batch_size` **[int | "auto" | "auto:N", default: 1]** : Sets the batch size used for evaluation. Options:
- Integer: Fixed batch size (e.g., `8`)
- `"auto"`: Automatically select the largest batch size that fits in memory
- `"auto:N"`: Re-select maximum batch size N times during evaluation
- `--include_path` : Accepts a path to a folder. If passed, then all YAML files containing `lm-eval` compatible task configurations will be added to the task registry as available tasks. Used for when one is writing config files for their own task in a folder other than `lm_eval/tasks/`.
Auto mode is useful since `lm-eval` sorts documents in descending order of context length.
- `--system_instruction`: Specifies a system instruction string to prepend to the prompt.
- `--max_batch_size` **[int]** : Sets the maximum batch size to try when using `--batch_size auto`.
- `--apply_chat_template` : This flag specifies whether to apply a chat template to the prompt. It can be used in the following ways:
- `--apply_chat_template` : When used without an argument, applies the only available chat template to the prompt. For Hugging Face models, if no dedicated chat template exists, the default chat template will be applied.
- `--apply_chat_template template_name` : If the model has multiple chat templates, apply the specified template to the prompt.
- `--device` **[str]** : Sets which device to place the model onto. Examples: `"cuda"`, `"cuda:0"`, `"cpu"`, `"mps"`. Can be ignored if running multi-GPU or non-local model types.
For Hugging Face models, the default chat template can be found in the [`default_chat_template`](https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1912) property of the Transformers Tokenizer.
- `--gen_kwargs` **[comma-sep str | json str → dict]** : Generation arguments for `generate_until` tasks. Same format as `--model_args`:
- Comma-separated: `temperature=0.8,top_p=0.95`
- JSON: `'{"temperature": 0.8, "top_p": 0.95}'`
- `--fewshot_as_multiturn` : If this flag is on, the Fewshot examples are treated as a multi-turn conversation. Questions are provided as user content and answers are provided as assistant responses. Requires `--num_fewshot` to be set to be greater than 0, and `--apply_chat_template` to be on.
See model documentation (e.g., `transformers.AutoModelForCausalLM.generate()`) for supported arguments. Applied to all generation tasks - use task YAML files for per-task control.
- `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results.
#### Data and Output
- `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g, `--seed 42` sets all three seeds to 42.
- `--output_path` **[path: str]** : Output location for results. Format options:
- Directory: `results/` - saves as `results/<model_name>_<timestamp>.json`
- File: `results/output.jsonl` - saves to specific file
- `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list [here](https://docs.wandb.ai/ref/python/init). e.g., ```--wandb_args project=test-project,name=test-run```. Also allows for the passing of the step to log things at (passed to `wandb.run.log`), e.g., `--wandb_args step=123`.
When used with `--log_samples`, per-document outputs are saved in the directory.
- `--hf_hub_log_args` : Logs evaluation results to Hugging Face Hub. Accepts a string with the arguments separated by commas. Available arguments:
- `hub_results_org` - organization name on Hugging Face Hub, e.g., `EleutherAI`. If not provided, the results will be pushed to the owner of the Hugging Face token,
- `hub_repo_name` - repository name on Hugging Face Hub (deprecated, `details_repo_name` and `results_repo_name` should be used instead), e.g., `lm-eval-results`,
- `details_repo_name` - repository name on Hugging Face Hub to store details, e.g., `lm-eval-results`,
- `results_repo_name` - repository name on Hugging Face Hub to store results, e.g., `lm-eval-results`,
- `push_results_to_hub` - whether to push results to Hugging Face Hub, can be `True` or `False`,
- `push_samples_to_hub` - whether to push samples results to Hugging Face Hub, can be `True` or `False`. Requires `--log_samples` to be set,
- `public_repo` - whether the repository is public, can be `True` or `False`,
- `leaderboard_url` - URL to the leaderboard, e.g., `https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard`.
- `point_of_contact` - Point of contact for the results dataset, e.g., `yourname@example.com`.
- `gated` - whether to gate the details dataset, can be `True` or `False`.
- `--log_samples` **[flag, default: False]** : Save model outputs and inputs at per-document granularity. Requires `--output_path`. Automatically enabled when using `--predict_only`.
- `--metadata`: JSON string to pass to TaskConfig. Used for some tasks which require additional metadata to be passed for processing. E.g., `--metadata '{"key": "value"}'`.
- `--limit` **[int | float]** : Limit evaluation examples per task. **WARNING: Only for testing!**
- Integer: First N documents (e.g., `100`)
- Float (0.0-1.0): Percentage of documents (e.g., `0.1` for 10%)
- `--samples` **[path | json str | dict → dict]** : Evaluate specific sample indices only. Input formats:
- JSON file path: `samples.json`
- JSON string: `'{"hellaswag": [0, 1, 2], "arc_easy": [10, 20]}'`
- Dictionary (programmatic use)
Format: `{"task_name": [indices], ...}`. Incompatible with `--limit`.
#### Caching and Performance
- `--use_cache` **[path: str]** : SQLite cache database path prefix. Creates per-process cache files:
- Single GPU: `/path/to/cache.db`
- Multi-GPU: `/path/to/cache_rank0.db`, `/path/to/cache_rank1.db`, etc.
Caches model outputs to avoid re-running the same (model, task) evaluations.
- `--cache_requests` **["true" | "refresh" | "delete"]** : Dataset request caching control:
- `"true"`: Use existing cache
- `"refresh"`: Regenerate cache (use after changing task configs)
- `"delete"`: Delete cache
Cache location: `lm_eval/cache/.cache` or `$LM_HARNESS_CACHE_PATH` if set.
- `--check_integrity` **[flag, default: False]** : Run task integrity tests to validate configurations.
#### Instruct Formatting
- `--system_instruction` **[str]** : Custom system instruction to prepend to prompts. Used with instruction-following models.
- `--apply_chat_template` **[bool | str, default: False]** : Apply chat template formatting. Usage:
- No argument: Apply default/only available template
- Template name: Apply specific template (e.g., `"chatml"`)
For HuggingFace models, uses the tokenizer's chat template. Default template defined in [`transformers` documentation](https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1912).
- `--fewshot_as_multiturn` **[flag, default: False]** : Format few-shot examples as multi-turn conversation:
- Questions → User messages
- Answers → Assistant responses
Requires: `--num_fewshot > 0` and `--apply_chat_template` enabled.
#### Task Management
- `--include_path` **[path: str]** : Directory containing custom task YAML files. All `.yaml` files in this directory will be registered as available tasks. Use for custom tasks outside of `lm_eval/tasks/`.
#### Logging and Tracking
- `--verbosity` **[str]** : **DEPRECATED** - Use `LOGLEVEL` environment variable instead.
- `--write_out` **[flag, default: False]** : Print first document's prompt and target for each task. Useful for debugging prompt formatting.
- `--show_config` **[flag, default: False]** : Display full task configurations after evaluation. Shows all non-default settings from task YAML files.
- `--wandb_args` **[comma-sep str → dict]** : Weights & Biases integration. Arguments for `wandb.init()`:
- Example: `project=my-project,name=run-1,tags=test`
- Special: `step=123` sets logging step
- See [W&B docs](https://docs.wandb.ai/ref/python/init) for all options
- `--wandb_config_args` **[comma-sep str → dict]** : Additional W&B config arguments, same format as `--wandb_args`.
- `--hf_hub_log_args` **[comma-sep str → dict]** : Hugging Face Hub logging configuration. Format: `key1=value1,key2=value2`. Options:
- `hub_results_org`: Organization name (default: token owner)
- `details_repo_name`: Repository for detailed results
- `results_repo_name`: Repository for aggregated results
- `push_results_to_hub`: Enable pushing (`True`/`False`)
- `push_samples_to_hub`: Push samples (`True`/`False`, requires `--log_samples`)
- `public_repo`: Make repo public (`True`/`False`)
- `leaderboard_url`: Associated leaderboard URL
- `point_of_contact`: Contact email
- `gated`: Gate the dataset (`True`/`False`)
- ~~`hub_repo_name`~~: Deprecated, use `details_repo_name` and `results_repo_name`
#### Advanced Options
- `--predict_only` **[flag, default: False]** : Generate outputs without computing metrics. Automatically enables `--log_samples`. Use to get raw model outputs.
- `--seed` **[int | comma-sep str → list[int], default: [0,1234,1234,1234]]** : Set random seeds for reproducibility:
- Single integer: Same seed for all (e.g., `42`)
- Four values: `python,numpy,torch,fewshot` seeds (e.g., `0,1234,8,52`)
- Use `None` to skip setting a seed (e.g., `0,None,8,52`)
Default preserves backward compatibility.
- `--trust_remote_code` **[flag, default: False]** : Allow executing remote code from Hugging Face Hub. **Security Risk**: Required for some models with custom code.
- `--confirm_run_unsafe_code` **[flag, default: False]** : Acknowledge risks when running tasks that execute arbitrary Python code (e.g., code generation tasks).
- `--metadata` **[json str → dict]** : Additional metadata for specific tasks. Format: `'{"key": "value"}'`. Required by tasks like RULER that need extra configuration.
## External Library Usage
......
import logging
import os
from .api import metrics, model, registry # initializes the registries
from .filters import *
__version__ = "0.4.9"
__version__ = "0.4.9.1"
# Lazy-load .evaluator module to improve CLI startup
......
import argparse
import json
import logging
import os
import sys
from functools import partial
from pathlib import Path
from typing import Union
from rich.traceback import install
import lm_eval.tasks
from lm_eval._cli.harness import HarnessCLI
from lm_eval.utils import setup_logging
def try_parse_json(value: str) -> Union[str, dict, None]:
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError:
if "{" in value:
raise argparse.ArgumentTypeError(
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
)
return value
install(show_locals=True)
def _int_or_none_list_arg_type(
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
def parse_value(item):
item = item.strip().lower()
if item == "none":
return None
try:
return int(item)
except ValueError:
raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
items = [parse_value(v) for v in value.split(split_char)]
num_items = len(items)
if num_items == 1:
# Makes downstream handling the same for single and multiple values
items = items * max_len
elif num_items < min_len or num_items > max_len:
raise argparse.ArgumentTypeError(
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
)
elif num_items != max_len:
logging.warning(
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
"Missing values will be filled with defaults."
)
default_items = [parse_value(v) for v in defaults.split(split_char)]
items.extend(
default_items[num_items:]
) # extend items list with missing defaults
return items
def check_argument_types(parser: argparse.ArgumentParser):
"""
Check to make sure all CLI args are typed, raises error if not
"""
for action in parser._actions:
if action.dest != "help" and not action.const:
if action.type is None:
raise ValueError(
f"Argument '{action.dest}' doesn't have a type specified."
)
else:
continue
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument(
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
)
parser.add_argument(
"--tasks",
"-t",
default=None,
type=str,
metavar="task1,task2",
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
)
parser.add_argument(
"--model_args",
"-a",
default="",
type=try_parse_json,
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
)
parser.add_argument(
"--num_fewshot",
"-f",
type=int,
default=None,
metavar="N",
help="Number of examples in few-shot context",
)
parser.add_argument(
"--batch_size",
"-b",
type=str,
default=1,
metavar="auto|auto:N|N",
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=None,
metavar="N",
help="Maximal batch size to try with --batch_size auto.",
)
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to use (e.g. cuda, cuda:0, cpu).",
)
parser.add_argument(
"--output_path",
"-o",
default=None,
type=str,
metavar="DIR|DIR/file.json",
help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
)
parser.add_argument(
"--limit",
"-L",
type=float,
default=None,
metavar="N|0<N<1",
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument(
"--samples",
"-E",
default=None,
type=str,
metavar="/path/to/json",
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
)
parser.add_argument(
"--use_cache",
"-c",
type=str,
default=None,
metavar="DIR",
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
)
parser.add_argument(
"--cache_requests",
type=str,
default=None,
choices=["true", "refresh", "delete"],
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
)
parser.add_argument(
"--check_integrity",
action="store_true",
help="Whether to run the relevant part of the test suite for the tasks.",
)
parser.add_argument(
"--write_out",
"-w",
action="store_true",
default=False,
help="Prints the prompt for the first few documents.",
)
parser.add_argument(
"--log_samples",
"-s",
action="store_true",
default=False,
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
)
parser.add_argument(
"--system_instruction",
type=str,
default=None,
help="System instruction to be used in the prompt",
)
parser.add_argument(
"--apply_chat_template",
type=str,
nargs="?",
const=True,
default=False,
help=(
"If True, apply chat template to the prompt. "
"Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
"To apply a specific template from the available list of templates, provide the template name as an argument. "
"E.g. `--apply_chat_template template_name`"
),
)
parser.add_argument(
"--fewshot_as_multiturn",
action="store_true",
default=False,
help="If True, uses the fewshot as a multi-turn conversation",
)
parser.add_argument(
"--show_config",
action="store_true",
default=False,
help="If True, shows the the full config of all tasks at the end of the evaluation.",
)
parser.add_argument(
"--include_path",
type=str,
default=None,
metavar="DIR",
help="Additional path to include if there are external tasks to include.",
)
parser.add_argument(
"--gen_kwargs",
type=try_parse_json,
default=None,
help=(
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
),
)
parser.add_argument(
"--verbosity",
"-v",
type=str.upper,
default=None,
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
)
parser.add_argument(
"--wandb_args",
type=str,
default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
)
parser.add_argument(
"--wandb_config_args",
type=str,
default="",
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
)
parser.add_argument(
"--hf_hub_log_args",
type=str,
default="",
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
)
parser.add_argument(
"--predict_only",
"-x",
action="store_true",
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
default_seed_string = "0,1234,1234,1234"
parser.add_argument(
"--seed",
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
default=default_seed_string, # for backward compatibility
help=(
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
"respectively, or a single integer to set the same seed for all four.\n"
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
"(for backward compatibility).\n"
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
"Here numpy's seed is not set since the second value is `None`.\n"
"E.g, `--seed 42` sets all four seeds to 42."
),
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
parser.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
parser.add_argument(
"--metadata",
type=json.loads,
default=None,
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
)
return parser
def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
check_argument_types(parser)
return parser.parse_args()
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if not args:
# we allow for args to be passed externally, else we parse them ourselves
parser = setup_parser()
args = parse_eval_args(parser)
# defer loading `lm_eval` submodules for faster CLI load
from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import (
handle_non_serializable,
make_table,
simple_parse_args_string,
)
if args.wandb_args:
wandb_args_dict = simple_parse_args_string(args.wandb_args)
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
utils.setup_logging(args.verbosity)
eval_logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# update the evaluation tracker args with the output path and the HF token
if args.output_path:
args.hf_hub_log_args += f",output_path={args.output_path}"
if os.environ.get("HF_TOKEN", None):
args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
if args.fewshot_as_multiturn and args.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
)
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
metadata = (
simple_parse_args_string(args.model_args)
if isinstance(args.model_args, str)
else args.model_args
if isinstance(args.model_args, dict)
else {}
) | (
args.metadata
if isinstance(args.metadata, dict)
else simple_parse_args_string(args.metadata)
)
task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
)
if args.limit:
eval_logger.warning(
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.samples:
assert args.limit is None, (
"If --samples is not None, then --limit must be None."
)
if (samples := Path(args.samples)).is_file():
args.samples = json.loads(samples.read_text())
else:
args.samples = json.loads(args.samples)
if args.tasks is None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
elif args.tasks == "list":
print(task_manager.list_all_tasks())
sys.exit()
elif args.tasks == "list_groups":
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
sys.exit()
elif args.tasks == "list_tags":
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
sys.exit()
elif args.tasks == "list_subtasks":
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
sys.exit()
else:
if os.path.isdir(args.tasks):
import glob
task_names = []
yaml_path = os.path.join(args.tasks, "*.yaml")
for yaml_file in glob.glob(yaml_path):
config = lm_eval.tasks.load_yaml_config(yaml_file)
task_names.append(config)
else:
task_list = args.tasks.split(",")
task_names = task_manager.match_tasks(task_list)
for task in [task for task in task_list if task not in task_names]:
if os.path.isfile(task):
config = lm_eval.tasks.load_yaml_config(task)
task_names.append(config)
task_missing = [
task for task in task_list if task not in task_names and "*" not in task
] # we don't want errors if a wildcard ("*") task name was used
if task_missing:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
)
raise ValueError(
f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
)
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
eval_logger.info(
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
)
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
import datasets
from packaging.version import parse as vparse
if vparse(datasets.__version__) < vparse("4.0.0"):
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
if isinstance(args.model_args, dict):
args.model_args["trust_remote_code"] = True
else:
args.model_args = args.model_args + ",trust_remote_code=True"
(
eval_logger.info(f"Selected Tasks: {task_names}")
if eval_logger.getEffectiveLevel() >= logging.INFO
else print(f"Selected Tasks: {task_names}")
)
request_caching_args = request_caching_arg_to_dict(
cache_requests=args.cache_requests
)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
use_cache=args.use_cache,
limit=args.limit,
samples=args.samples,
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
evaluation_tracker=evaluation_tracker,
system_instruction=args.system_instruction,
apply_chat_template=args.apply_chat_template,
fewshot_as_multiturn=args.fewshot_as_multiturn,
gen_kwargs=args.gen_kwargs,
task_manager=task_manager,
predict_only=args.predict_only,
random_seed=args.seed[0],
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
metadata=metadata,
**request_caching_args,
)
if results is not None:
if args.log_samples:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
if args.show_config:
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
# Add W&B logging
if args.wandb_args:
try:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if args.log_samples:
wandb_logger.log_eval_samples(samples)
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
evaluation_tracker.save_results_aggregated(
results=results, samples=samples if args.log_samples else None
)
if args.log_samples:
for task_name, config in results["configs"].items():
evaluation_tracker.save_results_samples(
task_name=task_name, samples=samples[task_name]
)
if (
evaluation_tracker.push_results_to_hub
or evaluation_tracker.push_samples_to_hub
):
evaluation_tracker.recreate_metadata_card()
print(
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
if args.wandb_args:
# Tear down wandb run once all the logging is done.
wandb_logger.run.finish()
def cli_evaluate() -> None:
"""Main CLI entry point."""
setup_logging()
parser = HarnessCLI()
args = parser.parse_args()
parser.execute(args)
if __name__ == "__main__":
......
"""
CLI subcommands to run from terminal.
"""
import argparse
import sys
import textwrap
from lm_eval._cli.ls import List
from lm_eval._cli.run import Run
from lm_eval._cli.validate import Validate
class HarnessCLI:
"""Main CLI parser that manages all subcommands."""
def __init__(self):
self._parser = argparse.ArgumentParser(
prog="lm-eval",
description="Language Model Evaluation Harness",
epilog=textwrap.dedent("""
quick start:
# Basic evaluation
lm-eval run --model hf --model_args pretrained=gpt2 --tasks hellaswag
# List available tasks
lm-eval ls tasks
# Validate task configurations
lm-eval validate --tasks hellaswag,arc_easy
legacy compatibility:
The harness maintains backward compatibility with the original interface.
If no command is specified, 'run' is automatically inserted:
lm-eval --model hf --tasks hellaswag # Equivalent to 'lm-eval run --model hf --tasks hellaswag'
For documentation, visit: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md
"""),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
self._parser.set_defaults(func=lambda args: self._parser.print_help())
self._subparsers = self._parser.add_subparsers(
dest="command", help="Available commands", metavar="COMMAND"
)
Run.create(self._subparsers)
List.create(self._subparsers)
Validate.create(self._subparsers)
def parse_args(self) -> argparse.Namespace:
"""Parse arguments using the main parser."""
if len(sys.argv) > 2 and sys.argv[1] not in self._subparsers.choices:
# Backward compatibility: arguments provided but no valid subcommand - insert 'run'
# TODO: add warning
sys.argv.insert(1, "run")
elif len(sys.argv) == 2 and "run" in sys.argv:
# if only 'run' is specified, ensure it is treated as a subcommand
self._subparsers.choices["run"].print_help()
sys.exit(0)
return self._parser.parse_args()
def execute(self, args: argparse.Namespace) -> None:
"""Main execution method that handles subcommands and legacy support."""
args.func(args)
import argparse
import textwrap
from lm_eval._cli.subcommand import SubCommand
class List(SubCommand):
"""Command for listing available tasks."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
# Create and configure the parser
super().__init__(*args, **kwargs)
self._parser = subparsers.add_parser(
"ls",
help="List available tasks, groups, subtasks, or tags",
description="List available tasks, groups, subtasks, or tags from the evaluation harness.",
usage="lm-eval list [tasks|groups|subtasks|tags] [--include_path DIR]",
epilog=textwrap.dedent("""
examples:
# List all available tasks (includes groups, subtasks, and tags)
$ lm-eval ls tasks
# List only task groups (like 'mmlu', 'glue', 'superglue')
$ lm-eval ls groups
# List only individual subtasks (like 'mmlu_abstract_algebra')
$ lm-eval ls subtasks
# Include external task definitions
$ lm-eval ls tasks --include_path /path/to/external/tasks
# List tasks from multiple external paths
$ lm-eval ls tasks --include_path "/path/to/tasks1:/path/to/tasks2"
organization:
• Groups: Collections of tasks with aggregated metric across subtasks (e.g., 'mmlu')
• Subtasks: Individual evaluation tasks (e.g., 'mmlu_anatomy', 'hellaswag')
• Tags: Similar to groups but no aggregate metric (e.g., 'reasoning', 'knowledge', 'language')
• External Tasks: Custom tasks defined in external directories
evaluation usage:
After listing tasks, use them with the run command!
For more information tasks configs are defined in https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks
"""),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
self._add_args()
self._parser.set_defaults(func=self._execute)
def _add_args(self) -> None:
self._parser.add_argument(
"what",
choices=["tasks", "groups", "subtasks", "tags"],
nargs="?",
help="What to list: tasks (all), groups, subtasks, or tags",
)
self._parser.add_argument(
"--include_path",
type=str,
default=None,
metavar="DIR",
help="Additional path to include if there are external tasks.",
)
def _execute(self, args: argparse.Namespace) -> None:
"""Execute the list command."""
from lm_eval.tasks import TaskManager
task_manager = TaskManager(include_path=args.include_path)
if args.what == "tasks":
print(task_manager.list_all_tasks())
elif args.what == "groups":
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
elif args.what == "subtasks":
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
elif args.what == "tags":
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
elif args.what is None:
self._parser.print_help()
import argparse
import json
import logging
import os
import textwrap
from functools import partial
from lm_eval._cli.subcommand import SubCommand
from lm_eval._cli.utils import (
_int_or_none_list_arg_type,
key_val_to_dict,
merge_dicts,
request_caching_arg_to_dict,
try_parse_json,
)
class Run(SubCommand):
"""Command for running language model evaluation."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
super().__init__(*args, **kwargs)
self._parser = subparsers.add_parser(
"run",
help="Run the evaluation harness on specified tasks",
description="Evaluate language models on various benchmarks and tasks.",
usage="lm-eval run --model <model> --tasks <task> <task> --model_args <arg=value> <arg=value> [options]",
epilog=textwrap.dedent("""
examples:
# Basic evaluation with HuggingFace model
$ lm-eval run --model hf --model_args pretrained=gpt2 dtype=float32 --tasks hellaswag
# Evaluate on multiple tasks with few-shot examples
$ lm-eval run --model vllm --model_args pretrained=EleutherAI/gpt-j-6B --tasks arc_easy arc_challenge --num_fewshot 5
# Evaluation with custom generation parameters
$ lm-eval run --model hf --model_args pretrained=gpt2 --tasks lambada --gen_kwargs temperature=0.8 top_p=0.95 'stop=["\\n\\n"]'
# Use configuration file
$ lm-eval run --config my_config.yaml --tasks mmlu
For more information, see: https://github.com/EleutherAI/lm-evaluation-harness
"""),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
self._add_args()
self._parser.set_defaults(func=self._execute)
def _add_args(self) -> None:
self._parser = self._parser
# Defaults are set in config/evaluate_config.py
config_group = self._parser.add_argument_group("configuration")
config_group.add_argument(
"--config",
"-C",
default=None,
type=str,
metavar="YAML_PATH",
help="Set initial arguments from YAML config",
)
# Model and Tasks
model_group = self._parser.add_argument_group("model and tasks")
model_group.add_argument(
"--model",
"-m",
type=str,
default=None,
metavar="MODEL_NAME",
help="Model name (default: hf)",
)
model_group.add_argument(
"--tasks",
"-t",
default=None,
type=str,
nargs="*",
metavar="TASK1 TASK2",
help=textwrap.dedent("""
Space or Comma-separated list of task names or groupings.
Use 'lm-eval list tasks' to see all available tasks.
""").strip(),
)
model_group.add_argument(
"--model_args",
"-a",
default=None,
nargs="*",
type=key_val_to_dict,
metavar="ARGS",
help="Model arguments as 'key=val,key2=val2' or `key=val` `key2=val2`",
)
# Evaluation Settings
eval_group = self._parser.add_argument_group("evaluation settings")
eval_group.add_argument(
"--num_fewshot",
"-f",
type=int,
default=None,
metavar="N",
help="Number of examples in few-shot context",
)
eval_group.add_argument(
"--batch_size",
"-b",
type=str,
default=argparse.SUPPRESS,
metavar="auto|auto:N|N",
help=textwrap.dedent(
"Batch size: 'auto', 'auto:N' (auto-tune N times), or integer (default: 1)"
),
)
eval_group.add_argument(
"--max_batch_size",
type=int,
default=None,
metavar="N",
help="Maximum batch size when using --batch_size auto",
)
eval_group.add_argument(
"--device",
type=str,
default=None,
metavar="DEVICE",
help="Device to use (e.g. cuda, cuda:0, cpu, mps)",
)
eval_group.add_argument(
"--gen_kwargs",
type=key_val_to_dict,
default=None,
nargs="*",
metavar="KWARGS",
help=textwrap.dedent(
'Generation arguments as `temperature=0,stop=["stop"]` or `key=val` `key2=val2`.'
"Values should be parsable with ast.literal_eval."
),
)
# Data and Output
data_group = self._parser.add_argument_group("data and output")
data_group.add_argument(
"--output_path",
"-o",
default=None,
type=str,
metavar="OUTPUT_PATH",
help="Output dir or json file for results (and samples)",
)
data_group.add_argument(
"--log_samples",
"-s",
action="store_true",
default=argparse.SUPPRESS,
help="Save all model outputs and documents for post-hoc analysis",
)
data_group.add_argument(
"--limit",
"-L",
type=float,
default=None,
metavar="N|0.0-1.0",
help="Limit examples per task (integer count or fraction)",
)
data_group.add_argument(
"--samples",
"-E",
default=None,
type=try_parse_json,
metavar='"task1": [1,2,3,4,...]"',
help=textwrap.dedent(
"`...` `...` Sample indices for inputs. Incompatible with --limit."
" Values be parsable with ast.literal_eval."
),
)
# Caching and Performance
cache_group = self._parser.add_argument_group("caching and performance")
cache_group.add_argument(
"--use_cache",
"-c",
type=str,
default=None,
metavar="CACHE_DIR",
help="SQLite database path for caching model outputs.",
)
cache_group.add_argument(
"--cache_requests",
type=request_caching_arg_to_dict,
default=None,
choices=["true", "refresh", "delete"],
help="Cache dataset request building (true|refresh|delete)",
)
cache_group.add_argument(
"--check_integrity",
action="store_true",
default=argparse.SUPPRESS,
help="Run task test suite validation",
)
# Prompt Formatting
template_group = self._parser.add_argument_group("instruct formatting")
template_group.add_argument(
"--system_instruction",
type=str,
default=None,
metavar="INSTRUCTION",
help="Add custom system instruction.",
)
template_group.add_argument(
"--apply_chat_template",
type=str,
nargs="?",
const=True,
default=argparse.SUPPRESS,
metavar="TEMPLATE",
help="Apply chat template to prompts (optional template name)",
)
template_group.add_argument(
"--fewshot_as_multiturn",
action="store_true",
default=argparse.SUPPRESS,
help="Use fewshot examples as multi-turn conversation",
)
# Task Management
task_group = self._parser.add_argument_group("task management")
task_group.add_argument(
"--include_path",
type=str,
default=None,
metavar="TASK_DIR",
help="Additional directory for external tasks",
)
# Logging and Tracking
logging_group = self._parser.add_argument_group("logging and tracking")
logging_group.add_argument(
"--verbosity",
"-v",
type=str.upper,
default=None,
metavar="LEVEL",
help="(Deprecated) Log level. Use LOGLEVEL env var instead",
)
logging_group.add_argument(
"--write_out",
"-w",
action="store_true",
default=argparse.SUPPRESS,
help="Print prompts for first few documents",
)
logging_group.add_argument(
"--show_config",
action="store_true",
default=argparse.SUPPRESS,
help="Display full task configuration after evaluation",
)
logging_group.add_argument(
"--wandb_args",
type=key_val_to_dict,
default=argparse.SUPPRESS,
metavar="ARGS",
help="Weights & Biases init arguments key=val key2=val2",
)
logging_group.add_argument(
"--wandb_config_args",
type=key_val_to_dict,
default=argparse.SUPPRESS,
metavar="ARGS",
help="Weights & Biases config arguments key=val key2=val2",
)
logging_group.add_argument(
"--hf_hub_log_args",
type=key_val_to_dict,
default=argparse.SUPPRESS,
metavar="ARGS",
help="Hugging Face Hub logging arguments key=val key2=val2",
)
# Advanced Options
advanced_group = self._parser.add_argument_group("advanced options")
advanced_group.add_argument(
"--predict_only",
"-x",
action="store_true",
default=argparse.SUPPRESS,
help="Save predictions only, skip metric computation",
)
default_seed_string = "0,1234,1234,1234"
advanced_group.add_argument(
"--seed",
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
default=None,
metavar="SEED|S1,S2,S3,S4",
help=textwrap.dedent(f"""
Random seeds for python,numpy,torch,fewshot (default: {default_seed_string}).
Use single integer for all, or comma-separated list of 4 values.
Use 'None' to skip setting a seed. Example: --seed 42 or --seed 0,None,8,52
""").strip(),
)
advanced_group.add_argument(
"--trust_remote_code",
action="store_true",
default=argparse.SUPPRESS,
help="Allow executing remote code from Hugging Face Hub",
)
advanced_group.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
default=argparse.SUPPRESS,
help="Confirm understanding of unsafe code execution risks",
)
advanced_group.add_argument(
"--metadata",
type=json.loads,
default=None,
metavar="`key=val` `key2=val2`",
help=textwrap.dedent(
"""`key=val` `key2=val` args parsable by ast.literal_eval (merged with model_args),
required for some tasks such as RULER"""
),
)
@staticmethod
def _execute(args: argparse.Namespace) -> None:
"""Runs the evaluation harness with the provided arguments."""
os.environ["TOKENIZERS_PARALLELISM"] = "false"
MERGE_ARGS_DICTS = [
"model_args",
"gen_kwargs",
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
]
for arg_name in MERGE_ARGS_DICTS:
if current_value := getattr(args, arg_name, None):
setattr(args, arg_name, merge_dicts(*current_value))
from lm_eval.config.evaluate_config import EvaluatorConfig
eval_logger = logging.getLogger(__name__)
# Create and validate config (most validation now occurs in EvaluationConfig)
cfg = EvaluatorConfig.from_cli(args)
from lm_eval import simple_evaluate
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.utils import handle_non_serializable, make_table
# Set up logging
if cfg.wandb_args:
wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)
# Set up evaluation tracker
if cfg.output_path:
cfg.hf_hub_log_args["output_path"] = cfg.output_path
if os.environ.get("HF_TOKEN", None):
cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")
evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)
# Create task manager (metadata already set up in config validation)
task_manager = cfg.process_tasks(cfg.metadata)
# Validation warnings (keep these in CLI as they're logging-specific)
if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set."
)
# Log task selection (tasks already processed in config)
if cfg.include_path is not None:
eval_logger.info(f"Including path: {cfg.include_path}")
eval_logger.info(f"Selected Tasks: {cfg.tasks}")
# Run evaluation
results = simple_evaluate(
model=cfg.model,
model_args=cfg.model_args,
tasks=cfg.tasks,
num_fewshot=cfg.num_fewshot,
batch_size=cfg.batch_size,
max_batch_size=cfg.max_batch_size,
device=cfg.device,
use_cache=cfg.use_cache,
cache_requests=cfg.cache_requests.get("cache_requests", False),
rewrite_requests_cache=cfg.cache_requests.get(
"rewrite_requests_cache", False
),
delete_requests_cache=cfg.cache_requests.get(
"delete_requests_cache", False
),
limit=cfg.limit,
samples=cfg.samples,
check_integrity=cfg.check_integrity,
write_out=cfg.write_out,
log_samples=cfg.log_samples,
evaluation_tracker=evaluation_tracker,
system_instruction=cfg.system_instruction,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
gen_kwargs=cfg.gen_kwargs,
task_manager=task_manager,
verbosity=cfg.verbosity,
predict_only=cfg.predict_only,
random_seed=cfg.seed[0] if cfg.seed else None,
numpy_random_seed=cfg.seed[1] if cfg.seed else None,
torch_random_seed=cfg.seed[2] if cfg.seed else None,
fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
metadata=cfg.metadata,
)
# Process results
if results is not None:
if cfg.log_samples:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
if cfg.show_config:
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
# W&B logging
if cfg.wandb_args:
try:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if cfg.log_samples:
wandb_logger.log_eval_samples(samples)
except Exception as e:
eval_logger.info(f"Logging to W&B failed: {e}")
# Save results
evaluation_tracker.save_results_aggregated(
results=results, samples=samples if cfg.log_samples else None
)
if cfg.log_samples:
for task_name, _ in results["configs"].items():
evaluation_tracker.save_results_samples(
task_name=task_name, samples=samples[task_name]
)
if (
evaluation_tracker.push_results_to_hub
or evaluation_tracker.push_samples_to_hub
):
evaluation_tracker.recreate_metadata_card()
# Print results
cfg.model_args.pop("trust_remote_code", None)
print(
f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), "
f"limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
if cfg.wandb_args:
wandb_logger.run.finish()
import argparse
from abc import ABC, abstractmethod
class SubCommand(ABC):
"""Base class for all subcommands."""
def __init__(self, *args, **kwargs):
pass
@classmethod
def create(cls, subparsers: argparse._SubParsersAction):
"""Factory method to create and register a command instance."""
return cls(subparsers)
@abstractmethod
def _add_args(self) -> None:
"""Add arguments specific to this subcommand."""
pass
import argparse
import ast
import json
import logging
from typing import Any, Optional, Union
def try_parse_json(value: Union[str, dict, None]) -> Union[str, dict, None]:
"""Try to parse a string as JSON. If it fails, return the original string."""
if value is None:
return None
if isinstance(value, dict):
return value
try:
return json.loads(value)
except json.JSONDecodeError:
if "{" in value:
raise ValueError(
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
)
return value
def _int_or_none_list_arg_type(
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
) -> list[Union[int, None]]:
"""Parses a string of integers or 'None' values separated by a specified character into a list.
Validates the number of items against specified minimum and maximum lengths and fills missing values with defaults."""
def parse_value(item):
"""Parses an individual item, converting it to an integer or `None`."""
item = item.strip().lower()
if item == "none":
return None
try:
return int(item)
except ValueError:
raise ValueError(f"{item} is not an integer or None")
items = [parse_value(v) for v in value.split(split_char)]
num_items = len(items)
if num_items == 1:
items = items * max_len
elif num_items < min_len or num_items > max_len:
raise ValueError(
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
)
elif num_items != max_len:
logging.warning(
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
"Missing values will be filled with defaults."
)
default_items = [parse_value(v) for v in defaults.split(split_char)]
items.extend(default_items[num_items:])
return items
def request_caching_arg_to_dict(cache_requests: Optional[str]) -> dict[str, bool]:
"""Convert a request caching argument to a dictionary."""
if cache_requests is None:
return {}
request_caching_args = {
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
return request_caching_args
def check_argument_types(parser: argparse.ArgumentParser) -> None:
"""
Check to make sure all CLI args are typed, raises error if not
"""
for action in parser._actions:
# Skip help, subcommands, and const actions
if action.dest in ["help", "command"] or action.const is not None:
continue
if action.type is None:
raise ValueError(f"Argument '{action.dest}' doesn't have a type specified.")
else:
continue
def handle_cli_value_string(arg: str) -> Any:
if arg.lower() == "true":
return True
elif arg.lower() == "false":
return False
elif arg.isnumeric():
return int(arg)
try:
return float(arg)
except ValueError:
try:
return ast.literal_eval(arg)
except (ValueError, SyntaxError):
return arg
def key_val_to_dict(args: str) -> dict:
"""Parse model arguments from a string into a dictionary."""
return (
{
k: handle_cli_value_string(v)
for k, v in (item.split("=") for item in args.split(","))
}
if args
else {}
)
def merge_dicts(*dicts):
return {k: v for d in dicts for k, v in d.items()}
import argparse
import sys
import textwrap
from lm_eval._cli.subcommand import SubCommand
class Validate(SubCommand):
"""Command for validating tasks."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
# Create and configure the self._parser
super().__init__(*args, **kwargs)
self._parser = subparsers.add_parser(
"validate",
help="Validate task configurations",
description="Validate task configurations and check for errors.",
usage="lm-eval validate --tasks <task1,task2> [--include_path DIR]",
epilog=textwrap.dedent("""
examples:
# Validate a single task
lm-eval validate --tasks hellaswag
# Validate multiple tasks
lm-eval validate --tasks arc_easy,arc_challenge,hellaswag
# Validate a task group
lm-eval validate --tasks mmlu
# Validate tasks with external definitions
lm-eval validate --tasks my_custom_task --include_path ./custom_tasks
# Validate tasks from multiple external paths
lm-eval validate --tasks custom_task1,custom_task2 --include_path "/path/to/tasks1:/path/to/tasks2"
validation check:
The validate command performs several checks:
• Task existence: Verifies all specified tasks are available
• Configuration syntax: Checks YAML/JSON configuration files
• Dataset access: Validates dataset paths and configurations
• Required fields: Ensures all mandatory task parameters are present
• Metric definitions: Verifies metric functions and aggregation methods
• Filter pipelines: Validates filter chains and their parameters
• Template rendering: Tests prompt templates with sample data
task config files:
Tasks are defined using YAML configuration files with these key sections:
• task: Task name and metadata
• dataset_path: HuggingFace dataset identifier
• doc_to_text: Template for converting documents to prompts
• doc_to_target: Template for extracting target answers
• metric_list: List of evaluation metrics to compute
• output_type: Type of model output (loglikelihood, generate_until, etc.)
• filter_list: Post-processing filters for model outputs
common errors:
• Missing required fields in YAML configuration
• Invalid dataset paths or missing dataset splits
• Malformed Jinja2 templates in doc_to_text/doc_to_target
• Undefined metrics or aggregation functions
• Invalid filter names or parameters
• Circular dependencies in task inheritance
• Missing external task files when using --include_path
debugging tips:
• Use --include_path to test external task definitions
• Check task configuration files for syntax errors
• Verify dataset access and authentication if needed
• Use 'lm-eval list tasks' to see available tasks
For task configuration guide, see: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md
"""),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
self._add_args()
self._parser.set_defaults(func=self._execute)
def _add_args(self) -> None:
self._parser.add_argument(
"--tasks",
"-t",
required=True,
type=str,
metavar="TASK1,TASK2",
help="Comma-separated list of task names to validate",
)
self._parser.add_argument(
"--include_path",
type=str,
default=None,
metavar="DIR",
help="Additional path to include if there are external tasks.",
)
def _execute(self, args: argparse.Namespace) -> None:
"""Execute the validate command."""
from lm_eval.tasks import TaskManager
task_manager = TaskManager(include_path=args.include_path)
task_list = args.tasks.split(",")
print(f"Validating tasks: {task_list}")
# For now, just validate that tasks exist
task_names = task_manager.match_tasks(task_list)
task_missing = [task for task in task_list if task not in task_names]
if task_missing:
missing = ", ".join(task_missing)
print(f"Tasks not found: {missing}")
sys.exit(1)
else:
print("All tasks found and valid")
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from typing import Protocol, runtime_checkable
from lm_eval.api.instance import Instance
class Filter(ABC):
@runtime_checkable
class Filter(Protocol):
"""
Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`)
......@@ -19,8 +20,9 @@ class Filter(ABC):
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
......@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
name: str
filters: List[Callable[[], Filter]]
filters: list[type[Filter]]
def apply(self, instances: List[Instance]) -> None:
def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)
......
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from inspect import getsource
from typing import Callable, Optional, Union
from datasets.features.pdf import field
@dataclass
class AggMetricConfig(dict):
metric: Optional[str] = None
aggregation: Optional[str] = "mean"
weight_by_size: Optional[str] = False
weight_by_size: bool = False
# list of filter names which should be incorporated into the aggregated metric.
filter_list: Optional[Union[str, list]] = "none"
......@@ -31,6 +29,7 @@ class GroupConfig:
aggregate_metric_list: Optional[
Union[list[AggMetricConfig], AggMetricConfig, dict]
] = None
version: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
......@@ -68,6 +67,11 @@ class GroupConfig:
AggMetricConfig(**item) if isinstance(item, dict) else item
for item in self.aggregate_metric_list
]
self.version = (
self.version or self.metadata.get("version", "1.0")
if self.metadata
else "1.0"
)
def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format.
......
......@@ -14,10 +14,23 @@ class Instance:
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
default_factory=lambda: (None, None, None),
metadata=dict(
description="Metadata tuple containing task name, document ID, and number of repeats."
),
)
resps: list = field(
default_factory=list,
metadata=dict(
description="List of responses from the model for this instance."
),
)
filtered_resps: dict = field(
default_factory=dict,
metadata=dict(
description="List of filtered responses for this instance, keyed by filter name."
),
)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: Optional[str] = None
......@@ -29,7 +42,7 @@ class Instance:
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self):
def args(self) -> tuple:
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
......
from __future__ import annotations
import logging
import math
import os
import random
import re
import string
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, TypeVar
from collections.abc import Callable, Iterable, Sequence
from typing import Generic, TypeVar
import numpy as np
import sacrebleu
from lm_eval.api.registry import register_aggregation, register_metric
......@@ -25,36 +26,36 @@ def bypass_agg(arr):
@register_aggregation("nanmean")
def nanmean(arr):
def nanmean(arr: list[float]) -> float:
if len(arr) == 0 or all(np.isnan(arr)):
return np.nan
return np.nanmean(arr)
@register_aggregation("mean")
def mean(arr):
def mean(arr: Sequence[float]) -> float:
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
def median(arr: list[float]) -> float:
return arr[len(arr) // 2]
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
def perplexity(items: list[float]) -> float:
return math.exp(-mean(items))
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
def weighted_perplexity(items: list[tuple[float, float]]) -> float:
return math.exp(-weighted_mean(items))
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
def bits_per_byte(items: list[tuple[float, float]]) -> float:
return -weighted_mean(items) / math.log(2)
......@@ -71,7 +72,7 @@ def f1_score(items):
@register_aggregation("matthews_corrcoef")
def matthews_corrcoef(items):
def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
from sklearn.metrics import matthews_corrcoef
unzipped_list = list(zip(*items))
......@@ -81,7 +82,7 @@ def matthews_corrcoef(items):
@register_aggregation("bleu")
def bleu(items):
def bleu(items: Iterable[tuple[str, str]]):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
n-grams in the candidate translation to n-grams in the reference text, where
......@@ -92,6 +93,8 @@ def bleu(items):
Higher is better
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......@@ -107,6 +110,8 @@ def chrf(items):
Higher is better # TODO I think
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......@@ -114,7 +119,7 @@ def chrf(items):
@register_aggregation("ter")
def ter(items):
def ter(items: Iterable[tuple[str, str]]):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
......@@ -123,6 +128,8 @@ def ter(items):
Lower is better
"""
import sacrebleu
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
......@@ -130,7 +137,9 @@ def ter(items):
@register_aggregation("brier_score")
def brier_score(items): # This is a passthrough function
def brier_score(
items: Iterable[tuple[str, float]],
): # This is a passthrough function
gold, predictions = list(zip(*items))
bs, num_class = np.array(predictions).shape
......@@ -198,13 +207,48 @@ def acc_mutual_info_fn(items): # This is a passthrough function
# See the License for the specific language governing permissions and
# limitations under the License.
def exact_match_hf_evaluate(
predictions,
references,
regexes_to_ignore=None,
ignore_case=False,
ignore_punctuation=False,
ignore_numbers=False,
predictions: Iterable[str] | str,
references: Iterable[str] | str,
regexes_to_ignore: list[str] | None = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
multi_target: bool = False,
):
"""
Compute exact match scores between predictions and references.
This function computes the exact match score by comparing predictions
and references. It supports optional preprocessing steps such as ignoring
case, punctuation, numbers, and specific regex patterns.
Note:
predictions and references can have different lengths.
numpy broadcasting rule applies
Args:
predictions (Iterable[str] | str): The predicted strings to evaluate.
references (Iterable[str] | str): The reference strings to compare against.
regexes_to_ignore (list[str], optional): A list of regex patterns to remove
from both predictions and references before comparison. Defaults to None.
ignore_case (bool, optional): If True, ignores case differences during comparison.
Defaults to False.
ignore_punctuation (bool, optional): If True, removes punctuation from strings
before comparison. Defaults to False.
ignore_numbers (bool, optional): If True, removes numeric characters from strings
before comparison. Defaults to False.
multi_target (bool, optional): If True, returns 1.0 if any prediction matches any
reference, otherwise 0.0. Defaults to False.
Returns:
dict: A dictionary containing the exact match score:
- "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True.
"""
predictions, references = list(predictions), list(references)
assert len(predictions) == len(references) if not multi_target else True, (
"predictions and references must have the same length unless `multi_target` is True"
)
if regexes_to_ignore is not None:
for s in regexes_to_ignore:
predictions = np.array([re.sub(s, "", x) for x in predictions])
......@@ -229,7 +273,11 @@ def exact_match_hf_evaluate(
score_list = predictions == references
return {"exact_match": np.mean(score_list)}
return {
"exact_match": np.mean(score_list)
if not multi_target
else float(np.any(score_list))
}
###
......@@ -241,8 +289,8 @@ def exact_match_hf_evaluate(
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(**kwargs):
return exact_match_hf_evaluate(**kwargs)
def exact_match_fn(references: list[str], predictions: list[str], **kwargs):
return exact_match_hf_evaluate(predictions, references, **kwargs)
@register_metric(
......@@ -261,7 +309,7 @@ def perplexity_fn(items): # This is a passthrough function
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def word_perplexity_fn(items): # This is a passthrough function
def word_perplexity_fn(items: T) -> T: # This is a passthrough function
return items
......@@ -271,7 +319,7 @@ def word_perplexity_fn(items): # This is a passthrough function
output_type="loglikelihood_rolling",
aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items): # This is a passthrough function
def byte_perplexity_fn(items: T) -> T: # This is a passthrough function
return items
......@@ -281,7 +329,7 @@ def byte_perplexity_fn(items): # This is a passthrough function
output_type="loglikelihood_rolling",
aggregation="bits_per_byte",
)
def bits_per_byte_fn(items): # This is a passthrough function
def bits_per_byte_fn(items: T) -> T: # This is a passthrough function
return items
......@@ -290,7 +338,7 @@ def pop_stddev(arr):
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
def sample_stddev(arr: Sequence[T]) -> float:
def sample_stddev(arr: Sequence[float]) -> float:
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
......@@ -411,7 +459,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
def weighted_mean(items):
def weighted_mean(items: list[tuple[float, float]]) -> float:
a, b = zip(*items)
return sum(a) / sum(b)
......@@ -422,15 +470,15 @@ def is_non_str_iterable(obj):
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str])
# Sacrebleu expects (list[str], list[list[str])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions that I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
# We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
# Must become list[list[str]] with the inner list corresponding to preds
if not is_non_str_iterable(refs):
refs = list(refs)
if not is_non_str_iterable(refs[0]):
......@@ -438,7 +486,7 @@ def _sacreformat(refs, preds):
refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
# We expect preds to be list[str] or list[list[str]]. Must become list[str]
if not is_non_str_iterable(preds):
preds = list(preds)
if is_non_str_iterable(preds[0]):
......@@ -451,7 +499,7 @@ def _sacreformat(refs, preds):
# stderr stuff
class _bootstrap_internal:
class _bootstrap_internal(Generic[T]):
"""
Pool worker: `(i, xs)` → `n` bootstrap replicates
of `f(xs)`using a RNG seeded with `i`.
......@@ -534,7 +582,7 @@ def bootstrap_stderr(
def stderr_for_metric(
metric: Callable[[Sequence[T]], float], bootstrap_iters: int
) -> Optional[Callable[[Sequence[T]], float]]:
) -> Callable[[Sequence[T]], float] | None:
"""
Return a function that estimates the standard error of `metric(xs)`.
......@@ -564,10 +612,10 @@ def stderr_for_metric(
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
return stderr.get(metric, None)
return stderr.get(metric)
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
# Used to aggregate bootstrapped stderrs across subtasks in a group,
# when we are weighting by the size of each subtask.
#
......@@ -585,7 +633,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
return np.sqrt(pooled_sample_var / sum(sizes))
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
assert metrics is not None, (
"Need to pass a list of each subtask's metric for this stderr aggregation"
)
......@@ -617,7 +665,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None)
return np.sqrt(variance)
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
def aggregate_subtask_metrics(
metrics: list[float], sizes: list[float], weight_by_size: bool = True
):
# A helper function that is used to aggregate
# subtask scores cross-task.
# TODO: does not hold for non-mean aggregations
......@@ -626,4 +676,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
assert len(metrics) == len(sizes)
return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
from __future__ import annotations
import abc
import hashlib
import json
import logging
import os
from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, TypeVar
from tqdm import tqdm
......@@ -24,17 +27,17 @@ T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
"""Defines the interface that should be implemented by all LM subclasses.
LMs are assumed to take text (strings) as input and yield strings as output
LMs are assumed to take text (strings) as input and yield strings or logprobabilities as output
(inputs/outputs should be tokenization-agnostic.)
"""
# set rank and world size to a single process, by default.
self._rank = 0
self._world_size = 1
self.cache_hook: "CacheHook" = CacheHook(None)
self.cache_hook: CacheHook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
......@@ -59,7 +62,7 @@ class LM(abc.ABC):
pass
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> list[float]:
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
......@@ -67,7 +70,7 @@ class LM(abc.ABC):
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still a full-sized context.
multiple chunks, the last input will still have full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: BOS/EOS
......@@ -101,7 +104,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
@abc.abstractmethod
def generate_until(self, requests) -> list[str]:
def generate_until(self, requests: list[Instance]) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
......@@ -118,7 +121,7 @@ class LM(abc.ABC):
pass
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt=True
self, chat_history: list[dict], add_generation_prompt=True
) -> str:
"""
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
......@@ -137,7 +140,7 @@ class LM(abc.ABC):
@classmethod
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
cls: type[T], arg_string: str, additional_config: dict | None = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
......@@ -156,7 +159,7 @@ class LM(abc.ABC):
@classmethod
def create_from_arg_obj(
cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
cls: type[T], arg_dict: dict, additional_config: dict | None = None
) -> T:
"""
Creates an instance of the LM class using the given arg_obj
......@@ -176,14 +179,16 @@ class LM(abc.ABC):
return cls(**arg_dict, **additional_config)
@property
def rank(self):
def rank(self) -> int:
"""Returns the rank of the current process in a distributed setting."""
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return self._rank
@property
def world_size(self):
def world_size(self) -> int:
"""Returns the total number of processes in a distributed setting."""
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
......@@ -199,7 +204,7 @@ class LM(abc.ABC):
"To use this model with chat templates, please implement the 'tokenizer_name' property."
)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
"""Returns the chat template structure for user/assistant messages if a template is provided.
This method is intended to be overridden in a subclass to define a specific chat template format.
For models that do not support chat templates, this method returns None by default.
......@@ -207,7 +212,8 @@ class LM(abc.ABC):
return ""
def set_cache_hook(self, cache_hook: "CacheHook") -> None:
def set_cache_hook(self, cache_hook: CacheHook) -> None:
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
self.cache_hook = cache_hook
......@@ -218,14 +224,16 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
class CacheHook:
def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
def __init__(self, cachinglm: CachingLM | None) -> None:
"""CacheHook is used to cache responses from the LM."""
if cachinglm is None:
self.dbdict: Optional["SqliteDict"] = None
self.dbdict: SqliteDict | None = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr: str, req: Iterable[Any], res: Any) -> None:
"""Adds a partial result to the cache."""
if self.dbdict is None:
return
hsh = hash_args(attr, req)
......@@ -258,7 +266,7 @@ class CachingLM:
eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
return lm_attr
def _fn(requests: list["Instance"]) -> list["Instance"]:
def _fn(requests: list[Instance]) -> list[Instance]:
res = []
remaining_reqs = []
warned = False
......@@ -290,11 +298,8 @@ class CachingLM:
eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
)
if remaining_reqs:
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
else:
rem_res = []
rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
......@@ -313,7 +318,7 @@ class CachingLM:
return _fn
def get_cache_hook(self) -> "CacheHook":
def get_cache_hook(self) -> CacheHook:
return CacheHook(self)
......@@ -327,12 +332,13 @@ class TemplateLM(LM):
@property
@abc.abstractmethod
def eot_token_id(self):
def eot_token_id(self) -> int:
"""Returns the token ID for the end-of-text token (e.g., EOS)."""
pass
@property
def prefix_token_id(self):
# it is used as prefix for loglikelihood
def prefix_token_id(self) -> int:
"""Returns the token ID for the prefix token (e.g., BOS or EOS)."""
return self.eot_token_id
@abc.abstractmethod
......@@ -344,13 +350,33 @@ class TemplateLM(LM):
@abc.abstractmethod
def _loglikelihood_tokens(
self, requests: list["Instance"], **kwargs
self, requests: list[tuple[tuple[str, str], list[int], list[int]]], **kwargs
) -> list[tuple[float, bool]]:
"""Called by loglikelihood to compute log likelihoods for a list of requests.
Args:
requests: list[tuple[tuple[str, str], list[int], list[int]]]
A list of tuples where each tuple contains:
- (context, continuation) as a tuple of strings
- context_enc: list of token IDs for the context
- continuation_enc: list of token IDs for the continuation
Returns:
list[tuple[float, bool]]
A list of tuples where each tuple contains:
- logprob: float, the (summed) log probability of the continuation given the context
- isgreedy: bool, whether the continuation would be generated by greedy sampling from the context
See LM.loglikelihood for more details.
"""
pass
def _encode_pair(
self, context: str, continuation: str
) -> tuple[list[int], list[int]]:
"""Encodes a pair of context and continuation strings into token IDs.
We encode using encode(context+continuation) and then split into context and continuation.
"""
import transformers
n_spaces = len(context) - len(context.rstrip())
......@@ -373,8 +399,12 @@ class TemplateLM(LM):
return context_enc, continuation_enc
def loglikelihood(
self, requests: list["Instance"], disable_tqdm: bool = False
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
This calls `_loglikelihood_tokens` to compute the log likelihoods for a list of requests, after encoding.
"""
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
......@@ -394,14 +424,38 @@ class TemplateLM(LM):
def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False
) -> list[float]:
"""Compute rolling log-likelihood of a sequence using non-overlapping windows.
See LM.loglikelihood_rolling for more details.
"""
pass
@abc.abstractmethod
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
def generate_until(
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
"""Generate until a stopping sequence.
Args:
requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
context: str
Context string
gen_kwargs: dict
A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
Returns:
list[continuation, ...]
A list of model generated continuations.
continuation: str
The generated continuation.
See LM.generate_until for more details.
"""
pass
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
"""
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
......
import logging
from typing import Callable, Dict, Union
"""Registry system for lm_eval components.
import evaluate as hf_evaluate
This module provides a centralized registration system for models, tasks, metrics,
filters, and other components in the lm_eval framework. The registry supports:
from lm_eval.api.model import LM
eval_logger = logging.getLogger(__name__)
- Lazy loading with placeholders to improve startup time
- Type checking and validation
- Thread-safe registration and lookup
- Plugin discovery via entry points
- Backwards compatibility with legacy registration patterns
MODEL_REGISTRY = {}
## Usage Examples
### Registering a Model
```python
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
@register_model("my-model")
class MyModel(LM):
def __init__(self, **kwargs):
...
```
### Registering a Metric
```python
from lm_eval.api.registry import register_metric
@register_metric(
metric="my_accuracy",
aggregation="mean",
higher_is_better=True
)
def my_accuracy_fn(items):
...
```
### Registering with Lazy Loading
```python
# Register without importing the actual implementation
model_registry.register("lazy-model", lazy="my_package.models:LazyModel")
```
### Looking up Components
```python
from lm_eval.api.registry import get_model, get_metric
# Get a model class
model_cls = get_model("gpt-j")
model = model_cls(**config)
# Get a metric function
metric_fn = get_metric("accuracy")
```
"""
from __future__ import annotations
import importlib
import inspect
import threading
from collections.abc import Iterable
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from typing import Any, Callable, Generic, TypeVar, Union, cast
from lm_eval.api.filter import Filter
try:
import importlib.metadata as md # Python ≥3.10
except ImportError: # pragma: no cover – fallback for 3.8/3.9
import importlib_metadata as md # type: ignore
LEGACY_EXPORTS = [
"DEFAULT_METRIC_REGISTRY",
"AGGREGATION_REGISTRY",
"register_model",
"get_model",
"register_task",
"get_task",
"register_metric",
"get_metric",
"register_metric_aggregation",
"get_metric_aggregation",
"register_higher_is_better",
"is_higher_better",
"register_filter",
"get_filter",
"register_aggregation",
"get_aggregation",
"MODEL_REGISTRY",
"TASK_REGISTRY",
"METRIC_REGISTRY",
"METRIC_AGGREGATION_REGISTRY",
"HIGHER_IS_BETTER_REGISTRY",
"FILTER_REGISTRY",
]
__all__ = [
# canonical
"Registry",
"MetricSpec",
"model_registry",
"task_registry",
"metric_registry",
"metric_agg_registry",
"higher_is_better_registry",
"filter_registry",
"freeze_all",
*LEGACY_EXPORTS,
] # type: ignore
T = TypeVar("T")
Placeholder = Union[str, md.EntryPoint]
@lru_cache(maxsize=16)
def _materialise_placeholder(ph: Placeholder) -> Any:
"""Materialize a lazy placeholder into the actual object.
This is at module level to avoid memory leaks from lru_cache on instance methods.
Args:
ph: Either a string path "module:object" or an EntryPoint instance
Returns:
The loaded object
Raises:
ValueError: If the string format is invalid
ImportError: If the module cannot be imported
AttributeError: If the object doesn't exist in the module
"""
if isinstance(ph, str):
mod, _, attr = ph.partition(":")
if not attr:
raise ValueError(f"Invalid lazy path '{ph}', expected 'module:object'")
return getattr(importlib.import_module(mod), attr)
return ph.load()
# Metric-specific metadata storage --------------------------------------------
_metric_meta: dict[str, dict[str, Any]] = {}
class Registry(Generic[T]):
"""A thread-safe registry for named objects with lazy loading support.
The Registry provides a central location for registering and retrieving
components by name. It supports:
- Direct registration of objects
- Lazy registration with placeholders (strings or entry points)
- Type checking against a base class
- Thread-safe operations
- Freezing to prevent further modifications
Example:
>>> from lm_eval.api.model import LM
>>> registry = Registry("models", base_cls=LM)
>>>
>>> # Direct registration
>>> @registry.register("my-model")
>>> class MyModel(LM):
... pass
>>>
>>> # Lazy registration
>>> registry.register("lazy-model", lazy="mypackage:LazyModel")
>>>
>>> # Retrieval (triggers lazy loading if needed)
>>> model_cls = registry.get("my-model")
>>> model = model_cls()
"""
def __init__(
self,
name: str,
*,
base_cls: type[T] | None = None,
) -> None:
"""Initialize a new registry.
Args:
name: Human-readable name for error messages (e.g., "model", "metric")
base_cls: Optional base class that all registered objects must inherit from
"""
self._name = name
self._base_cls = base_cls
self._objs: dict[str, T | Placeholder] = {}
self._lock = threading.RLock()
# Registration (decorator or direct call) --------------------------------------
def register(
self,
*aliases: str,
lazy: T | Placeholder | None = None,
) -> Callable[[T], T]:
"""Register an object under one or more aliases.
Can be used as a decorator or called directly for lazy registration.
Args:
*aliases: Names to register the object under. If empty, uses object's __name__
lazy: For direct calls only - a placeholder string "module:object" or EntryPoint
Returns:
Decorator function (or no-op if lazy registration)
Examples:
>>> # As decorator
>>> @model_registry.register("name1", "name2")
>>> class MyModel(LM):
... pass
>>>
>>> # Direct lazy registration
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises:
ValueError: If alias already registered with different target
TypeError: If object doesn't inherit from base_cls (when specified)
"""
def _store(alias: str, target: T | Placeholder) -> None:
current = self._objs.get(alias)
# collision handling ------------------------------------------
if current is not None and current != target:
# allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type):
# mod, _, cls = current.partition(":")
if current == f"{target.__module__}:{target.__name__}":
self._objs[alias] = target
return
raise ValueError(
f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})"
)
# type check for concrete classes ----------------------------------------------
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
self._objs[alias] = target
def decorator(obj: T) -> T: # type: ignore[valid-type]
names = aliases or (getattr(obj, "__name__", str(obj)),)
with self._lock:
for name in names:
_store(name, obj)
return obj
# Direct call with *lazy* placeholder
if lazy is not None:
if len(aliases) != 1:
raise ValueError("Exactly one alias required when using 'lazy='")
with self._lock:
_store(aliases[0], lazy) # type: ignore[arg-type]
# return no‑op decorator for accidental use
return lambda x: x # type: ignore[return-value]
return decorator
# Lookup & materialisation --------------------------------------------------
def _materialise(self, ph: Placeholder) -> T:
"""Materialize a placeholder using the module-level cached function.
Args:
ph: Placeholder to materialize
Returns:
The materialized object, cast to type T
"""
return cast(T, _materialise_placeholder(ph))
def get(self, alias: str) -> T:
"""Retrieve an object by alias, materializing if needed.
Thread-safe lazy loading: if the alias points to a placeholder,
it will be loaded and cached before returning.
Args:
alias: The registered name to look up
Returns:
The registered object
Raises:
KeyError: If alias not found
TypeError: If materialized object doesn't match base_cls
ImportError/AttributeError: If lazy loading fails
"""
try:
target = self._objs[alias]
except KeyError as exc:
raise KeyError(
f"Unknown {self._name} '{alias}'. Available: {', '.join(self._objs)}"
) from exc
if isinstance(target, (str, md.EntryPoint)):
with self._lock:
# Re‑check under lock (another thread might have resolved it)
fresh = self._objs[alias]
if isinstance(fresh, (str, md.EntryPoint)):
concrete = self._materialise(fresh)
# Only update if not frozen (MappingProxyType)
if not isinstance(self._objs, MappingProxyType):
self._objs[alias] = concrete
else:
concrete = fresh # another thread did the job
target = concrete
def decorate(cls):
for name in names:
assert issubclass(cls, LM), (
f"Model '{name}' ({cls.__name__}) must extend LM class"
# Late type/validator checks
if self._base_cls is not None and not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} does not inherit from {self._base_cls} (alias '{alias}')"
)
return target
assert name not in MODEL_REGISTRY, (
f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
)
def __getitem__(self, alias: str) -> T:
"""Allow dict-style access: registry[alias]."""
return self.get(alias)
MODEL_REGISTRY[name] = cls
return cls
def __iter__(self):
"""Iterate over registered aliases."""
return iter(self._objs)
return decorate
def __len__(self):
"""Return number of registered aliases."""
return len(self._objs)
def items(self):
"""Return (alias, object) pairs.
def get_model(model_name):
try:
return MODEL_REGISTRY[model_name]
except KeyError:
raise ValueError(
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
)
Note: Objects may be placeholders that haven't been materialized yet.
"""
return self._objs.items()
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()
func2task_index = {}
# Utilities -------------------------------------------------------------
def origin(self, alias: str) -> str | None:
"""Get the source location of a registered object.
def register_task(name):
def decorate(fn):
assert name not in TASK_REGISTRY, (
f"task named '{name}' conflicts with existing registered task!"
)
Args:
alias: The registered name
TASK_REGISTRY[name] = fn
ALL_TASKS.add(name)
func2task_index[fn.__name__] = name
return fn
Returns:
"path/to/file.py:line_number" or None if not available
"""
obj = self._objs.get(alias)
if isinstance(obj, (str, md.EntryPoint)):
return None
try:
path = inspect.getfile(obj) # type: ignore[arg-type]
line = inspect.getsourcelines(obj)[1] # type: ignore[arg-type]
return f"{path}:{line}"
except Exception: # pragma: no cover – best‑effort only
return None
return decorate
def freeze(self):
"""Make the registry read-only to prevent further modifications.
After freezing, attempts to register new objects will fail.
This is useful for ensuring registry contents don't change after
initialization.
"""
with self._lock:
self._objs = MappingProxyType(dict(self._objs)) # type: ignore[assignment]
def register_group(name):
def decorate(fn):
func_name = func2task_index[fn.__name__]
if name in GROUP_REGISTRY:
GROUP_REGISTRY[name].append(func_name)
else:
GROUP_REGISTRY[name] = [func_name]
ALL_TASKS.add(name)
return fn
# Test helper --------------------------------
def _clear(self): # pragma: no cover
"""Erase registry (for isolated tests).
return decorate
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = {
"loglikelihood": [
"perplexity",
"acc",
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"generate_until": ["exact_match"],
}
def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn):
assert "metric" in args
name = args["metric"]
for key, registry in [
("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY),
]:
if key in args:
value = args[key]
assert value not in registry, (
f"{key} named '{value}' conflicts with existing registered {key}!"
)
Clears both the registry contents and the materialization cache.
Only use this in test code to ensure clean state between tests.
"""
self._objs.clear()
_materialise_placeholder.cache_clear()
if key == "metric":
registry[name] = fn
elif key == "aggregation":
registry[name] = AGGREGATION_REGISTRY[value]
else:
registry[name] = value
# Structured object for metrics ------------------
@dataclass(frozen=True)
class MetricSpec:
"""Specification for a metric including computation and aggregation functions.
Attributes:
compute: Function to compute metric on individual items
aggregate: Function to aggregate multiple metric values into a single score
higher_is_better: Whether higher values indicate better performance
output_type: Optional type hint for the output (e.g., "generate_until" for perplexity)
requires: Optional list of other metrics this one depends on
"""
compute: Callable[[Any, Any], Any]
aggregate: Callable[[Iterable[Any]], float]
higher_is_better: bool = True
output_type: str | None = None
requires: list[str] | None = None
# Canonical registries aliases ---------------------
from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = cast(
Registry[type[LM]], Registry("model", base_cls=LM)
)
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
"metric aggregation"
)
higher_is_better_registry: Registry[bool] = Registry("higher‑is‑better flag")
filter_registry: Registry[type[Filter]] = Registry("filter")
# Public helper aliases ------------------------------------------------------
register_model = model_registry.register
get_model = model_registry.get
register_task = task_registry.register
get_task = task_registry.get
register_filter = filter_registry.register
get_filter = filter_registry.get
# Metric helpers need thin wrappers to build MetricSpec ----------------------
def _no_aggregation_fn(values: Iterable[Any]) -> float:
"""Default aggregation that raises NotImplementedError.
Args:
values: Metric values to aggregate (unused)
Raises:
NotImplementedError: Always - this is a placeholder for metrics
that haven't specified an aggregation function
"""
raise NotImplementedError(
"No aggregation function specified for this metric. "
"Please specify 'aggregation' parameter in @register_metric."
)
def register_metric(**kw):
"""Decorator for registering metric functions.
Creates a MetricSpec from the decorated function and keyword arguments,
then registers it in the metric registry.
Args:
**kw: Keyword arguments including:
- metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True)
- output_type: Optional output type hint
- requires: Optional list of required metrics
Returns:
Decorator function that registers the metric
Example:
>>> @register_metric(
... metric="my_accuracy",
... aggregation="mean",
... higher_is_better=True
... )
... def compute_accuracy(items):
... return sum(item["correct"] for item in items) / len(items)
"""
name = kw["metric"]
def deco(fn):
spec = MetricSpec(
compute=fn,
aggregate=(
metric_agg_registry.get(kw["aggregation"])
if "aggregation" in kw
else _no_aggregation_fn
),
higher_is_better=kw.get("higher_is_better", True),
output_type=kw.get("output_type"),
requires=kw.get("requires"),
)
metric_registry.register(name, lazy=spec)
_metric_meta[name] = kw
higher_is_better_registry.register(name, lazy=spec.higher_is_better)
return fn
return decorate
return deco
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
else:
eval_logger.warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
)
def get_metric(name, hf_evaluate_metric=False):
"""Get a metric compute function by name.
try:
metric_object = hf_evaluate.load(name)
return metric_object.compute
except Exception:
eval_logger.error(
f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
)
First checks the local metric registry, then optionally falls back
to HuggingFace evaluate library.
Args:
name: Metric name to retrieve
hf_evaluate_metric: If True, suppress warning when falling back to HF
def register_aggregation(name: str):
def decorate(fn):
assert name not in AGGREGATION_REGISTRY, (
f"aggregation named '{name}' conflicts with existing registered aggregation!"
)
Returns:
The metric's compute function
AGGREGATION_REGISTRY[name] = fn
return fn
Raises:
KeyError: If metric not found in registry or HF evaluate
"""
try:
spec = metric_registry.get(name)
return spec.compute # type: ignore[attr-defined]
except KeyError:
if not hf_evaluate_metric:
import logging
return decorate
logging.getLogger(__name__).warning(
f"Metric '{name}' not in registry; trying HF evaluate…"
)
try:
import evaluate as hf
return hf.load(name).compute # type: ignore[attr-defined]
except Exception:
raise KeyError(f"Metric '{name}' not found anywhere")
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
register_metric_aggregation = metric_agg_registry.register
get_metric_aggregation = metric_agg_registry.get
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
register_higher_is_better = higher_is_better_registry.register
is_higher_better = higher_is_better_registry.get
# Legacy compatibility
register_aggregation = metric_agg_registry.register
get_aggregation = metric_agg_registry.get
DEFAULT_METRIC_REGISTRY = metric_registry
AGGREGATION_REGISTRY = metric_agg_registry
def is_higher_better(metric_name) -> bool:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!"
)
def freeze_all():
"""Freeze all registries to prevent further modifications.
def register_filter(name):
def decorate(cls):
if name in FILTER_REGISTRY:
eval_logger.info(
f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
)
FILTER_REGISTRY[name] = cls
return cls
This is useful for ensuring registry contents are immutable after
initialization, preventing accidental modifications during runtime.
"""
for r in (
model_registry,
task_registry,
metric_registry,
metric_agg_registry,
higher_is_better_registry,
filter_registry,
):
r.freeze()
return decorate
# Backwards‑compat aliases ----------------------------------------
def get_filter(filter_name: Union[str, Callable]) -> Callable:
try:
return FILTER_REGISTRY[filter_name]
except KeyError as e:
if callable(filter_name):
return filter_name
else:
eval_logger.warning(f"filter `{filter_name}` is not registered!")
raise e
MODEL_REGISTRY = model_registry
TASK_REGISTRY = task_registry
METRIC_REGISTRY = metric_registry
METRIC_AGGREGATION_REGISTRY = metric_agg_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry
FILTER_REGISTRY = filter_registry
from __future__ import annotations
import logging
import warnings
from collections.abc import Iterable, Sequence
from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Union
from typing import TYPE_CHECKING, Any
import datasets
......@@ -18,9 +21,9 @@ class ContextSampler:
def __init__(
self,
docs: list[dict],
task: Union["Task", "ConfigurableTask"],
fewshot_indices: Optional[Iterable] = None,
rnd: Optional["Random"] = None,
task: Task | ConfigurableTask,
fewshot_indices: Iterable | None = None,
rnd: Random | None = None,
) -> None:
self.rnd = rnd
if not self.rnd:
......@@ -75,7 +78,7 @@ class ContextSampler:
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
......@@ -95,10 +98,13 @@ class ContextSampler:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if self.config.doc_to_choice is None or isinstance(doc_content, str):
if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content
else:
labeled_examples += self.doc_to_choice(doc)[doc_content]
if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
......@@ -126,7 +132,7 @@ class ContextSampler:
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None,
gen_prefix: str | None = None,
):
# TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else ""
......@@ -181,16 +187,22 @@ class ContextSampler:
return chat_history
def sample(self, n: int):
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
assert self.rnd is not None, (
"Error: `rnd` must be set to a random.Random instance before sampling."
)
return self.rnd.sample(self.docs, n)
class FirstNSampler(ContextSampler):
def sample(self, n: int) -> None:
def sample(self, n: int) -> Sequence[dict[str, Any]]:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
def sample(self, n: int) -> None:
def sample(self, n: int):
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
"""
pass
raise NotImplementedError
class ManualSampler(ContextSampler):
def sample(self, n: int) -> None:
def sample(self, n: int):
""" """
pass
raise NotImplementedError
SAMPLER_REGISTRY = {
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
"default": ContextSampler,
"first_n": FirstNSampler,
}
......@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
def get_sampler(name: str):
try:
return SAMPLER_REGISTRY[name]
except KeyError:
raise ValueError(
except KeyError as e:
raise KeyError(
f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
)
) from e
from __future__ import annotations
import abc
import ast
import logging
......@@ -5,36 +7,22 @@ import random
import re
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, overload
import datasets
import numpy as np
from tqdm import tqdm
from typing_extensions import deprecated
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.utils import check_gold_index_error
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import DataSet, TaskConfig
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
ALL_OUTPUT_TYPES = [
......@@ -44,139 +32,11 @@ ALL_OUTPUT_TYPES = [
"generate_until",
]
eval_logger = logging.getLogger(__name__)
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[str] = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
doc_to_audio: Union[Callable, str] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: Optional[int] = None
# scoring options
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={self.fewshot_delimiter!r}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
def __getitem__(self, item):
return getattr(self, item)
if TYPE_CHECKING:
pass
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
Used for dumping results alongside full task configuration
:return: dict
A printable dictionary version of the TaskConfig object.
# TODO: should any default value in the TaskConfig not be printed?
"""
cfg_dict = asdict(self)
# remove values that are `None`
for k, v in list(cfg_dict.items()):
if v is None:
cfg_dict.pop(k)
elif k == "metric_list":
for metric_dict in v:
for metric_key, metric_value in metric_dict.items():
if callable(metric_value):
metric_dict[metric_key] = self.serialize_function(
metric_value, keep_callable=keep_callable
)
cfg_dict[k] = v
elif callable(v):
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
return cfg_dict
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
eval_logger = logging.getLogger(__name__)
class Task(abc.ABC):
......@@ -189,23 +49,23 @@ class Task(abc.ABC):
{"question": ..., question, answer)
"""
VERSION: Optional[Union[int, str]] = None
VERSION: int | str | None = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: Optional[str] = None
DATASET_PATH: str | None = None
# The name of a subset within `DATASET_PATH`.
DATASET_NAME: Optional[str] = None
DATASET_NAME: str | None = None
OUTPUT_TYPE: Optional[OutputType] = None
OUTPUT_TYPE: OutputType | None = None
def __init__(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode: Optional[datasets.DownloadMode] = None,
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
data_dir: str | None = None,
cache_dir: str | None = None,
download_mode: datasets.DownloadMode | None = None,
config: Mapping | None = None, # Union[dict, TaskConfig]
) -> None:
"""
:param data_dir: str
......@@ -229,21 +89,21 @@ class Task(abc.ABC):
Fresh download and fresh dataset.
"""
self.download(data_dir, cache_dir, download_mode)
self._training_docs: Optional[list] = None
self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None
self._training_docs: list | None = None
self._fewshot_docs: list | None = None
self._instances: list[Instance] | None = None
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._config: TaskConfig = TaskConfig.from_yaml({**config})
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
self.fewshot_rnd: Optional[random.Random] = (
self._filters = [build_filter_ensemble("none", [("take_first", None)])]
self.fewshot_rnd: random.Random | None = (
None # purposely induce errors in case of improper usage
)
def download(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
data_dir: str | None = None,
cache_dir: str | None = None,
download_mode=None,
) -> None:
"""Downloads and returns the task dataset.
......@@ -270,6 +130,7 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
assert self.DATASET_PATH is not None, "DATASET_PATH must be set in Task class"
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
......@@ -283,50 +144,53 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class."""
return self._config
@abc.abstractmethod
def has_training_docs(self):
@property
def has_training_docs(self) -> bool:
"""Whether the task has a training set"""
raise NotImplementedError
@abc.abstractmethod
def has_validation_docs(self):
@property
def has_validation_docs(self) -> bool:
"""Whether the task has a validation set"""
raise NotImplementedError
@abc.abstractmethod
def has_test_docs(self):
@property
def has_test_docs(self) -> bool:
"""Whether the task has a test set"""
raise NotImplementedError
def training_docs(self) -> Iterable:
def training_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def validation_docs(self) -> Iterable:
def validation_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def test_docs(self) -> Iterable:
def test_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def fewshot_docs(self) -> Iterable:
def fewshot_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
if self.has_training_docs():
if self.has_training_docs:
return self.training_docs()
elif self.has_validation_docs():
elif self.has_validation_docs:
return self.validation_docs()
else:
if self.config.get("num_fewshot", 0) > 0:
if self.config.num_fewshot and self.config.num_fewshot > 0:
eval_logger.warning(
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
......@@ -345,54 +209,54 @@ class Task(abc.ABC):
return doc
@property
def instances(self) -> List[Instance]:
def instances(self) -> list[Instance]:
"""After calling `task.build_all_requests()`, tasks
maintain a list of the dataset instances which will be evaluated.
"""
return self._instances
def fewshot_examples(self, k, rnd):
def fewshot_examples(self, k: int, rnd) -> Iterable[dict]:
if self._training_docs is None:
self._training_docs = list(self.training_docs())
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc):
def doc_to_decontamination_query(self, doc: dict):
raise NotImplementedError(
"Override doc_to_decontamination_query with document specific decontamination query."
)
@abc.abstractmethod
def doc_to_text(self, doc):
def doc_to_text(self, doc: dict) -> str:
pass
@abc.abstractmethod
def doc_to_target(self, doc):
def doc_to_target(self, doc: dict) -> str | int:
pass
# not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc):
def doc_to_image(self, doc: dict):
raise NotImplementedError
def doc_to_audio(self, doc):
def doc_to_audio(self, doc: dict):
raise NotImplementedError
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc: dict) -> str:
return ""
def build_all_requests(
self,
*,
limit: Union[int, None] = None,
samples: Optional[List[int]] = None,
limit: int | None = None,
samples: list[int] | None = None,
rank: int = 0,
world_size: int = 1,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
system_instruction: Optional[str] = None,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
chat_template: Callable | None = None,
tokenizer_name: str = "",
) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""
......@@ -465,7 +329,7 @@ class Task(abc.ABC):
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats),
metadata=(self.config.task, doc_id, self.config.repeats),
apply_chat_template=apply_chat_template,
chat_template=chat_template,
)
......@@ -494,7 +358,7 @@ class Task(abc.ABC):
save_to_cache(file_name=cache_key, obj=instances)
@abc.abstractmethod
def construct_requests(self, doc, ctx, **kwargs):
def construct_requests(self, doc: dict, ctx: list[dict] | str, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -514,7 +378,7 @@ class Task(abc.ABC):
"""
@abc.abstractmethod
def process_results(self, doc, results):
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
......@@ -524,33 +388,36 @@ class Task(abc.ABC):
:param results:
The results of the requests created in construct_requests.
"""
raise NotImplementedError
@abc.abstractmethod
@deprecated("not used anymore")
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
return True
@abc.abstractmethod
@deprecated("not used anymore")
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return True
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
@classmethod
def count_bytes(cls, doc):
def count_bytes(cls, doc: str) -> int:
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
return len(doc.encode("utf-8"))
@classmethod
def count_words(cls, doc):
def count_words(cls, doc: str) -> int:
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
......@@ -585,13 +452,13 @@ class Task(abc.ABC):
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
if self.has_training_docs:
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
if self.has_validation_docs
else self.test_docs()
)
......@@ -613,13 +480,15 @@ class Task(abc.ABC):
example = self.doc_to_text(doc)
return description + labeled_examples + example
def apply_filters(self) -> Optional[List[Instance]]:
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
if hasattr(self, "_filters") and self._instances:
for f in self._filters:
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
eval_logger.warning(
"No filter defined or no instances, passing through instances"
)
return self._instances
def dump_config(self) -> dict:
......@@ -630,9 +499,6 @@ class Task(abc.ABC):
def set_config(self, key: str, value: Any, update: bool = False) -> None:
"""Set or update the configuration for a given key."""
if key is None:
raise ValueError("Key must be provided.")
if update:
current_value = getattr(self._config, key, {})
if not isinstance(current_value, dict):
......@@ -650,34 +516,24 @@ class Task(abc.ABC):
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
if not isinstance(self, ConfigurableTask):
self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
self.aggregation = lambda: {
metric_name: get_metric_aggregation(metric_name)
}
self._config.metric_list = [{"metric": metric_name}]
self._config.process_results = None
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
# if not isinstance(self, ConfigurableTask):
# self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
# self.aggregation = lambda: {
# metric_name: get_metric_aggregation(metric_name)
# }
self._config.metric_list = [MetricConfig(name=metric_name)]
self._config.process_results = lambda *args: {"bypass": 0}
def set_fewshot_seed(self, seed: int | None = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd
@property
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
if self.has_test_docs():
def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
if self.has_test_docs:
return self.test_docs()
elif self.has_validation_docs():
elif self.has_validation_docs:
return self.validation_docs()
else:
raise ValueError(
......@@ -688,13 +544,13 @@ class Task(abc.ABC):
self,
*,
rank: int = 0,
limit: Union[int, None] = None,
limit: int | None = None,
world_size: int = 1,
samples: Optional[List[int]] = None,
) -> Iterator[Tuple[int, Any]]:
samples: list[int] | None = None,
) -> Iterator[tuple[int, Any]]:
if samples:
n = len(self.eval_docs)
assert all([e < n for e in samples]), (
assert all(e < n for e in samples), (
f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
)
eval_logger.info(
......@@ -727,14 +583,14 @@ class ConfigurableTask(Task):
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
) -> None: # TODO no super() call here
config: Mapping[str, Any] | None = None,
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
# Use new configurations if there was no preconfiguration
if self.config is None:
self._config = TaskConfig(**config)
self._config = TaskConfig.from_yaml(config)
# Overwrite configs
else:
if config is not None:
......@@ -745,9 +601,8 @@ class ConfigurableTask(Task):
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
if isinstance(self.config.metadata, dict):
if "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if isinstance(self.config.metadata, dict) and "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None:
if self.config.output_type not in ALL_OUTPUT_TYPES:
......@@ -773,294 +628,132 @@ class ConfigurableTask(Task):
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
hf_evaluate_metric = (
"hf_evaluate" in metric_config
and metric_config["hf_evaluate"] is True
)
if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = get_metric(
metric_name, hf_evaluate_metric
)
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
# self.metric_list: list[MetricConfig] = self.config.get_metrics
self.download(self.config.dataset_kwargs)
self._training_docs = None
self._fewshot_docs = None
if self.config.filter_list is not None:
self._filters = []
for filter_config in self.config.filter_list:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
else:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger.debug(
"No custom filters defined. Using default 'take_first' filter for handling repeats."
)
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self.config.use_prompt is not None:
eval_logger.info(f"loading prompt {self.config.use_prompt}")
self.prompt = get_prompt(
self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
)
else:
self.prompt = None
if self.fewshot_docs() is not None:
self.fewshot_rnd = (
random.Random()
) # setting with no seed, to be overridden at a later time
config_sampler: Union[str, Callable] = (
self.config.fewshot_config.get("sampler", "default")
if self.config.fewshot_config
else "default"
)
if isinstance(config_sampler, str):
self.sampler = samplers.get_sampler(config_sampler)(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
)
elif callable(config_sampler) and issubclass(
config_sampler, samplers.ContextSampler
):
self.sampler = config_sampler(
docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
)
else:
raise TypeError(
f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
f"not {type(config_sampler)}"
)
self.task_docs = self.eval_docs
# Test One Doc
self.features = list(self.task_docs.features.keys())
self.multiple_input = 0
self.multiple_target = 0
test_doc = self.task_docs[0]
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
self._filters = self.config.get_filters
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list")
else:
num_choice = len(test_choice)
# if self.config.use_prompt is not None:
# eval_logger.info(f"loading prompt {self.config.use_prompt}")
# self.prompt = get_prompt(
# self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
# )
# else:
# self.prompt = None
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
)
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
if (
self.config.fewshot_cfg.num_fewshot() > 0
and self.fewshot_docs() is not None
):
self.fewshot_rnd = random.Random()
self.sampler = self.config.fewshot_cfg.init_sampler(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target]
else:
test_target = str(test_target)
self.task_docs = self.eval_docs
if test_choice is not None:
check_choices = test_choice
else:
check_choices = [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
delimiter_has_whitespace = (
True
if self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
else False
)
# for name, fn in self.config._fn.items():
# if hasattr(self, name):
# setattr(
# self,
# name,
# types.MethodType(
# lambda self, *args, _fn=fn, **kwargs: _fn(*args, **kwargs),
# self,
# ),
# )
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
self.runtime_checks(self.task_docs[0])
def download(
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
self, dataset_kwargs:dict[str, Any] | None = None, **kwargs
) -> None:
from packaging.version import parse as vparse
self.config.dataset_kwargs, self.config.metadata = (
self.config.dataset_kwargs or {},
self.config.metadata or {},
)
if dataset_kwargs and vparse(datasets.__version__) >= vparse("4.0.0"):
dataset_kwargs.pop("trust_remote_code", None)
if isinstance(self.config.custom_dataset, Callable):
if isinstance(df := self.config.custom_dataset, Callable):
eval_logger.warning(
f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
+ "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
)
self.dataset = self.config.custom_dataset(
**(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
)
self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
else:
assert self.config.dataset_path is not None, (
"dataset_path must be set in TaskConfig"
)
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
path=self.config.dataset_path,
name=self.config.dataset_name,
**self.config.dataset_kwargs,
)
@cached_property
def has_training_docs(self) -> bool:
if self.config.training_split is not None:
return True
else:
return False
return self.config.training_split is not None
@cached_property
def has_validation_docs(self) -> bool:
if self.config.validation_split is not None:
return True
else:
return False
return self.config.validation_split is not None
@cached_property
def has_test_docs(self) -> bool:
if self.config.test_split is not None:
return True
else:
return False
return self.config.test_split is not None
def training_docs(self) -> datasets.Dataset:
if self.has_training_docs():
def training_docs(self) -> DataSet | None:
if self.has_training_docs:
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.training_split]
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs():
def validation_docs(self) -> DataSet | None:
if self.has_validation_docs:
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.validation_split]
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset:
if self.has_test_docs():
def test_docs(self) -> DataSet | None:
if self.has_test_docs:
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self.config.test_split]
def fewshot_docs(self):
if self.config.fewshot_split is not None:
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.fewshot_split])
return self.dataset[self.config.fewshot_split]
elif (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("samples", None) is not None
docs = self.config.fewshot_cfg.get_docs(self.dataset)
if docs is not None:
return docs
# Fallback to parent implementation
if (
(_num_fewshot := self.config.num_fewshot)
and isinstance(_num_fewshot, int)
and _num_fewshot > 0
):
if isinstance(self.config.fewshot_config["samples"], list):
return self.config.fewshot_config["samples"]
elif callable(self.config.fewshot_config["samples"]):
return self.config.fewshot_config["samples"]()
else:
raise Exception(
"`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
)
else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning(
f"[Task: {self.config.task}] "
"num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule."
)
return super().fewshot_docs()
eval_logger.warning(
f"[Task: {self.config.task}] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
return super().fewshot_docs()
@staticmethod
def append_target_question(
labeled_examples: List[Dict[str, str]],
labeled_examples: list[dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None,
gen_prefix: str | None = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
......@@ -1084,12 +777,12 @@ class ConfigurableTask(Task):
self,
doc: dict,
num_fewshot: int,
system_instruction: Optional[str] = None,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
gen_prefix: Optional[str] = None,
) -> Union[str, List[str]]:
chat_template: Callable | None = None,
gen_prefix: str | None = None,
) -> str | list[str] | None:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1110,10 +803,7 @@ class ConfigurableTask(Task):
:returns: str
The fewshot context.
"""
if apply_chat_template:
labeled_examples = []
else:
labeled_examples = ""
labeled_examples = [] if apply_chat_template else ""
# get task description
if description := self.config.description:
......@@ -1183,7 +873,7 @@ class ConfigurableTask(Task):
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=False if gen_prefix else True,
add_generation_prompt=not gen_prefix,
)
)
return labeled_examples_list
......@@ -1207,7 +897,7 @@ class ConfigurableTask(Task):
# return lm.apply_chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=False if gen_prefix else True,
add_generation_prompt=not gen_prefix,
)
else:
prefix = (
......@@ -1228,13 +918,15 @@ class ConfigurableTask(Task):
else:
return labeled_examples + str(example) + prefix
def apply_filters(self) -> Optional[List[Instance]]:
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
if hasattr(self, "_filters") and self._instances:
for f in self._filters:
f.apply(self._instances)
f.ensemble.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
eval_logger.warning(
"No filter defined or instances found. Passing through instances"
)
return self._instances
def should_decontaminate(self):
......@@ -1268,115 +960,167 @@ class ConfigurableTask(Task):
"""
return doc
def doc_to_text(self, doc, doc_to_text=None):
if self.prompt is not None:
doc_to_text = self.prompt
elif doc_to_text is not None:
doc_to_text = doc_to_text
else:
doc_to_text = self.config.doc_to_text
@overload
def doc_to_text(self, doc: dict, doc_to_text: None = None) -> str | int: ...
if isinstance(doc_to_text, int):
return doc_to_text
@overload
def doc_to_text(self, doc: dict, doc_to_text: int) -> int: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: str) -> str: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: Callable[..., str]) -> str: ...
def doc_to_text(
self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
) -> str | int:
# if self.prompt is not None:
# doc_to_text = self.prompt
doc_to_text = doc_to_text or self.config.doc_to_text
if callable(doc_to_text):
return doc_to_text(doc)
if doc_to_text in doc:
return doc[doc_to_text]
elif isinstance(doc_to_text, str):
if doc_to_text in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]]
# else:
return doc[doc_to_text]
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
return text_string
elif callable(doc_to_text):
return doc_to_text(doc)
return text_string
elif isinstance(doc_to_text, int):
return doc_to_text
# Used when applying a Promptsource template
elif hasattr(doc_to_text, "apply"):
applied_prompt = doc_to_text.apply(doc)
if len(applied_prompt) == 2:
return applied_prompt[0]
else:
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
# elif hasattr(doc_to_text, "apply"):
# applied_prompt = doc_to_text.apply(doc)
# if len(applied_prompt) == 2:
# return applied_prompt[0]
# else:
# eval_logger.warning("Applied prompt returns empty string")
# return self.config.fewshot_delimiter
else:
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
elif doc_to_target is not None:
doc_to_target = doc_to_target
else:
doc_to_target = self.config.doc_to_target
if isinstance(doc_to_target, int):
return doc_to_target
@overload
def doc_to_target(
self, doc: dict, doc_to_target: None = None
) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: int) -> int: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: str) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: list) -> list[int]: ...
@overload
def doc_to_target(
self, doc: dict, doc_to_target: Callable[..., int | str | list[int]]
) -> int | str | list[int]: ...
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None:
# doc_to_target = self.prompt
doc_to_target = doc_to_target or self.config.doc_to_target
if callable(doc_to_target):
doc_to_target(doc)
if doc_to_target in doc:
return doc[doc_to_target]
elif isinstance(doc_to_target, str):
if doc_to_target in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]]
# else:
return doc[doc_to_target]
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(target_string)
# elif (
# len(target_string) >= 2
# and (target_string[0] == "[")
# and (target_string[-1] == "]")
# ):
# try:
# return ast.literal_eval(target_string)
# except (SyntaxError, ValueError):
# return target_string
else:
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
len(target_string) >= 2
and (target_string[0] == "[")
and (target_string[-1] == "]")
):
try:
return ast.literal_eval(target_string)
except (SyntaxError, ValueError):
return target_string
else:
return target_string
elif isinstance(doc_to_target, list):
return target_string
elif isinstance(doc_to_target, (int, list)):
return doc_to_target
elif callable(doc_to_target):
return doc_to_target(doc)
# Used when applying a Promptsource template
elif hasattr(doc_to_target, "apply"):
applied_prompt = doc_to_target.apply(doc)
if len(applied_prompt) == 2:
return applied_prompt[1]
else:
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
# elif isinstance(doc_to_target, list):
# return doc_to_target
# elif callable(doc_to_target):
# return doc_to_target(doc)
# # Used when applying a Promptsource template
# elif hasattr(doc_to_target, "apply"):
# applied_prompt = doc_to_target.apply(doc)
# if len(applied_prompt) == 2:
# return applied_prompt[1]
# else:
# eval_logger.warning("Applied prompt returns empty string")
# return self.config.fewshot_delimiter
else:
raise TypeError
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
if self.prompt is not None:
doc_to_choice = self.prompt
elif doc_to_choice is not None:
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: None = None) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: str) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: list) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: dict) -> list[str]: ...
@overload
def doc_to_choice(
self, doc: dict, doc_to_choice: Callable[..., list[str]]
) -> list[str]: ...
def doc_to_choice(
self,
doc: dict,
doc_to_choice: str | list | dict | Callable[..., list[str]] | None = None,
) -> list[str]:
# if self.prompt is not None:
# doc_to_choice = self.prompt
if doc_to_choice is not None:
doc_to_choice = doc_to_choice
elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config")
doc_to_choice = None
else:
doc_to_choice = self.config.doc_to_choice
if isinstance(doc_to_choice, str):
if doc_to_choice in self.features:
if doc_to_choice in doc:
return doc[doc_to_choice]
else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list):
return doc_to_choice
elif isinstance(doc_to_choice, dict):
return list(doc_to_choice.values())
elif callable(doc_to_choice):
return doc_to_choice(doc)
elif hasattr(doc_to_choice, "get_answer_choices_list"):
return doc_to_choice.get_answer_choices_list(doc)
# elif isinstance(doc_to_choice, dict):
# return list(doc_to_choice.values())
# elif hasattr(doc_to_choice, "get_answer_choices_list"):
# return doc_to_choice.get_answer_choices_list(doc)
else:
raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
@overload
def doc_to_image(self, doc: dict, doc_to_image: None = None) -> None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: list) -> list: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: str) -> int | str | None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: Callable[..., Any]) -> Any: ...
def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
......@@ -1399,7 +1143,19 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: None = None) -> None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: list) -> list: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: str) -> int | str | None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: Callable[..., Any]) -> Any: ...
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
if doc_to_audio is not None:
doc_to_audio = doc_to_audio
elif self.config.doc_to_audio is not None:
......@@ -1422,9 +1178,9 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc: dict) -> str | None:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
if gen_prefix in doc:
return doc[gen_prefix]
else:
return utils.apply_template(gen_prefix, doc)
......@@ -1432,7 +1188,7 @@ class ConfigurableTask(Task):
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
) -> list[Instance] | Instance:
apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None)
......@@ -1469,7 +1225,7 @@ class ConfigurableTask(Task):
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
if "acc_mutual_info" in [m.metric_name for m in self.config._metric_list]:
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
......@@ -1531,12 +1287,11 @@ class ConfigurableTask(Task):
**kwargs,
)
def process_results(self, doc, results):
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
use_metric = list(m.metric_name for m in self.config._metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
......@@ -1545,9 +1300,12 @@ class ConfigurableTask(Task):
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
(loglikelihood, *_) = results
assert isinstance(_target := self.doc_to_target(doc), str), (
"Require target to be a string for loglikelihood_rolling"
)
_words = self.count_words(_target)
_bytes = self.count_bytes(_target)
return {
**(
{"word_perplexity": (loglikelihood, _words)}
......@@ -1568,14 +1326,11 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc.
# retrieve choices in list[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
):
if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
# as we extend the args list with unconditional ("", continuation) pairs
......@@ -1584,6 +1339,8 @@ class ConfigurableTask(Task):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[: len(choices)]
else:
lls_unconditional = None
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
......@@ -1593,19 +1350,7 @@ class ConfigurableTask(Task):
else:
gold = self.doc_to_target(doc)
gold_index_error = False
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
gold, gold_index_error = check_gold_index_error(choices, gold)
if gold_index_error:
eval_logger.warning(
......@@ -1616,7 +1361,7 @@ class ConfigurableTask(Task):
if self.multiple_target:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold))
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
......@@ -1641,6 +1386,9 @@ class ConfigurableTask(Task):
}
if "acc_mutual_info" in use_metric:
assert lls_unconditional is not None, (
"lls_unconditional should not be None if acc_mutual_info is in use_metric"
)
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
......@@ -1650,77 +1398,22 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
if self.config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
# TODO: handle this better
elif type(gold) is not type(result) and not (
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
):
# cast gold to the same type as result
gold = type(result)(gold)
for metric in self._metric_fn_list.keys():
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
# TODO: this may break for multipLe_target, non zero-or-1 metrics
scores = []
if not isinstance(gold, list):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
if metric == "exact_match":
result = [result for _ in range(len(gold))]
scores = self._metric_fn_list[metric](
references=gold,
predictions=result,
**self._metric_fn_kwargs[metric],
)[metric]
result_score = 1.0 if scores > 0.0 else 0.0
else:
for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
scores.append(result_score)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else:
try:
result_score = self._metric_fn_list[metric](
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
for metric in self.config._metric_list:
try:
result_score = metric.fn(
references=[gold] if not isinstance(gold, list) else gold,
predictions=[result],
**metric.kwargs,
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = metric.fn([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
for k, v in result_score.items():
result_dict[k] = v
else:
result_dict[metric] = result_score
result_dict[metric.name] = result_score
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
......@@ -1730,18 +1423,75 @@ class ConfigurableTask(Task):
return result_dict
def aggregation(self) -> dict:
return self._aggregation_list
return {k.name: k.aggregation_fn for k in self.config._metric_list}
def higher_is_better(self) -> dict:
return self._higher_is_better
return {k.name: k.higher_is_better for k in self.config._metric_list}
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
@property
def task_name(self) -> Any:
def task_name(self) -> str | None:
return getattr(self.config, "task", None)
def runtime_checks(self, test_doc):
# Test One Doc
self.features: list[str] = list(self.task_docs.features.keys())
self.multiple_target = 0
self.multiple_input = 0
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list")
else:
num_choice = len(test_choice)
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
)
if isinstance(test_text, int):
eval_logger.debug(
"doc_to_text returned an int. Assuming multiple inputs."
)
self.multiple_input = num_choice
else:
test_choice = None
if isinstance(test_target, list):
eval_logger.debug(
"doc_to_target returned a list. Assuming multiple targets."
)
self.multiple_target = len(test_target)
else:
if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target]
else:
test_target = str(test_target)
check_choices = test_choice if test_choice is not None else [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = (
self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def __repr__(self):
return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
......@@ -1757,7 +1507,7 @@ class MultipleChoiceTask(Task):
def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> list[Instance]:
# TODO: add mutual info here?
return [
Instance(
......@@ -1770,7 +1520,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"])
]
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
def process_results(self, doc: dict, results: Iterable[tuple[float, bool]]) -> dict:
results = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
......@@ -1806,7 +1556,7 @@ class PerplexityTask(Task):
def has_training_docs(self) -> bool:
return False
def fewshot_examples(self, k: int, rnd) -> List:
def fewshot_examples(self, k: int, rnd) -> list:
if k != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
......@@ -1837,7 +1587,7 @@ class PerplexityTask(Task):
def doc_to_target(self, doc):
return doc
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
def construct_requests(self, doc: dict, ctx: str | None, **kwargs):
if bool(ctx):
raise ValueError
......@@ -1849,7 +1599,7 @@ class PerplexityTask(Task):
**kwargs,
)
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
def process_results(self, doc: dict, results: tuple[float]) -> dict:
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
......
from __future__ import annotations
def check_gold_index_error(
choices: list[int] | list[str], gold: list[int] | int | str
) -> tuple[int | list[int], bool]:
gold_index_error = False
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
return gold, gold_index_error
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
return gold, gold_index_error
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment