Commit cb8889cc authored by lintangsutawika's avatar lintangsutawika

merged with latest update from main

parents ec05e561 74119471
name: Publish Python distribution to PyPI
on:
push:
tags:
- '*'
jobs:
build:
name: Build distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"
- name: Install pypa/build
run: >-
python3 -m
pip install
build
--user
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v3
with:
name: python-package-distributions
path: dist/
publish-to-pypi:
name: >-
Publish Python distribution to PyPI
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
needs:
- build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/lm_eval
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v3
with:
name: python-package-distributions
path: dist/
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
publish-to-testpypi:
name: Publish Python distribution to TestPyPI
needs:
- build
runs-on: ubuntu-latest
environment:
name: testpypi
url: https://test.pypi.org/p/lm_eval
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v3
with:
name: python-package-distributions
path: dist/
- name: Publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
......@@ -56,7 +56,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[dev,anthropic,sentencepiece,optimum]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
......
......@@ -45,26 +45,7 @@ git clone https://github.com/EleutherAI/lm-evaluation-harness
cd lm-evaluation-harness
pip install -e .
```
We also provide a number of optional dependencies for extended functionality. Extras can be installed via `pip install -e ".[NAME]"`
| Name | Use |
|---------------|---------------------------------------|
| anthropic | For using Anthropic's models |
| dev | For linting PRs and contributions |
| gptq | For loading models with GPTQ |
| ifeval | For running the IFEval task |
| mamba | For loading Mamba SSM models |
| math | For running math task answer checking |
| multilingual | For multilingual tokenizers |
| openai | For using OpenAI's models |
| promptsource | For using PromptSource prompts |
| sentencepiece | For using the sentencepiece tokenizer |
| testing | For running library test suite |
| vllm | For loading models with vLLM |
| zeno | For visualizing results with Zeno |
| all           | Loads all extras (not recommended)    |
We also provide a number of optional dependencies for extended functionality. A detailed table is available at the end of this document.
## Basic Usage
......@@ -145,6 +126,9 @@ For more advanced users or even larger models, we allow for the following argume
These two options (`accelerate launch` and `parallelize=True`) are mutually exclusive.
**Note: we do not currently support multi-node evaluations natively, and advise using either an externally hosted server to run inference requests against, or creating a custom integration with your distributed framework [as is done for the GPT-NeoX library](https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py).**
### Tensor + Data Parallel and Optimized Inference with `vLLM`
We also support vLLM for faster inference on [supported model types](https://docs.vllm.ai/en/latest/models/supported_models.html), which is especially helpful when splitting a model across multiple GPUs. For single-GPU or multi-GPU inference (tensor parallel, data parallel, or a combination of both), for example:
......@@ -189,10 +173,10 @@ Note that for externally hosted models, configs such as `--device` and `--batch_
| [Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | [All models supported by llama.cpp](https://github.com/ggerganov/llama.cpp) | `generate_until`, `loglikelihood`, (perplexity evaluation not yet implemented) |
| vLLM | :heavy_check_mark: | `vllm` | [Most HF Causal Language Models](https://docs.vllm.ai/en/latest/models/supported_models.html) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Mamba | :heavy_check_mark: | `mamba_ssm` | [Mamba architecture Language Models via the `mamba_ssm` package](https://huggingface.co/state-spaces) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Your local inference server! | :heavy_check_mark: | `local-completions` or `local-chat-completions` (using `openai-chat-completions` model type) | Any server address that accepts GET requests using HF models and mirrors OpenAI's ChatCompletions interface | `generate_until` | | ... |
| `local-completions` (using `openai-completions` model type) | Any server address that accepts GET requests using HF models and mirrors OpenAI's Completions interface | `generate_until` | | ... |
| Huggingface Optimum (Causal LMs) | ✔️ | `openvino` | Any decoder-only AutoModelForCausalLM converted with Huggingface Optimum into OpenVINO™ Intermediate Representation (IR) format | `generate_until`, `loglikelihood`, `loglikelihood_rolling` | ... |
| Your local inference server! | :heavy_check_mark: | `local-completions` or `local-chat-completions` (using `openai-chat-completions` model type) | Any server address that accepts GET requests using HF models and mirrors OpenAI's Completions or ChatCompletions interface | `generate_until` | | ... |
Models which do not supply logits or logprobs can be used with tasks of type `generate_until` only, while models that are local or APIs that supply logprobs/logits can be run on all task types: `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
Models which do not supply logits or logprobs can be used with tasks of type `generate_until` only, while local models, or APIs that supply logprobs/logits of their prompts, can be run on all task types: `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
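As an illustration of why logprobs matter for `multiple_choice`-style tasks, here is a minimal sketch (with hypothetical numbers, not real model output) of how a choice can be scored by summing per-token logprobs and taking the argmax:

```python
# Hypothetical per-token logprobs returned by a model for each answer
# choice of a single multiple-choice document.
choice_logprobs = {
    "Paris": [-0.2, -0.1],
    "London": [-1.5, -0.9],
    "Berlin": [-2.0, -1.1],
}

# Score each choice by the total logprob of its tokens, then predict
# the highest-scoring choice.
scores = {choice: sum(lps) for choice, lps in choice_logprobs.items()}
prediction = max(scores, key=scores.get)
print(prediction)  # -> Paris
```

A model that only returns generated text cannot produce these per-choice scores, which is why such models are limited to `generate_until` tasks.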
For more information on the different task `output_types` and model request types, see [our documentation](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md#interface).
......@@ -203,6 +187,8 @@ A number of other libraries contain scripts for calling the eval harness through
To create your own custom integration you can follow instructions from [this tutorial](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage).
### Additional Features
> [!Note]
> For tasks unsuitable for direct evaluation, whether due to risks associated with executing untrusted code or complexities in the evaluation process, the `--predict_only` flag is available to obtain decoded generations for post-hoc evaluation.
If you have a Metal compatible Mac, you can run the eval harness using the MPS back-end by replacing `--device cuda:0` with `--device mps` (requires PyTorch version 2.1 or higher).
......@@ -252,6 +238,9 @@ Additionally, one can provide a directory with `--use_cache` to cache the result
For a full list of supported arguments, check out the [interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md) guide in our documentation!
> [!Tip]
> Running lm-evaluation-harness as an external library and can't find (almost) any tasks available? Run `lm_eval.tasks.initialize_tasks()` to load the library's stock tasks before calling `lm_eval.evaluate()` or `lm_eval.simple_evaluate()`!
## Visualizing Results
You can use [Zeno](https://zenoml.com) to visualize the results of your eval harness runs.
......@@ -315,6 +304,28 @@ We try to prioritize agreement with the procedures used by other groups to decre
The best way to get support is to open an issue on this repo or join the [EleutherAI Discord server](https://discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for receiving support for our releases. If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!
## Optional Extras
Extras dependencies can be installed via `pip install -e ".[NAME]"`
| Name | Use |
|---------------|---------------------------------------|
| anthropic | For using Anthropic's models |
| dev | For linting PRs and contributions |
| gptq | For loading models with GPTQ |
| ifeval | For running the IFEval task |
| mamba | For loading Mamba SSM models |
| math | For running math task answer checking |
| multilingual | For multilingual tokenizers |
| openai | For using OpenAI's models |
| optimum | For running Intel OpenVINO models |
| promptsource | For using PromptSource prompts |
| sentencepiece | For using the sentencepiece tokenizer |
| testing | For running library test suite |
| vllm | For loading models with vLLM |
| zeno | For visualizing results with Zeno |
| all           | Loads all extras (not recommended)    |
## Cite as
```
......
# Contributing to LM Evaluation Harness
Welcome, and thank you for your interest in the LM Evaluation Harness! We welcome contributions and feedback, appreciate the time you spend with our library, and hope you find it useful!
We intend LM Evaluation Harness to be a broadly useful and extensible tool for evaluating language models.
## Important Resources
There are several places information about LM Evaluation Harness is located:
- Our [documentation pages](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs)
- We occasionally use [GitHub Milestones](https://github.com/EleutherAI/lm-evaluation-harness/milestones) to track progress toward specific near-term version releases.
- We maintain a [Project Board](https://github.com/orgs/EleutherAI/projects/25) for tracking current work items and PRs, and for future roadmap items or feature requests.
- Further discussion and support conversations are located in the #lm-thunderdome channel of the [EleutherAI discord](https://discord.gg/eleutherai).
## Code Style
LM Evaluation Harness uses [ruff](https://github.com/astral-sh/ruff) for linting via [pre-commit](https://pre-commit.com/).
You can install linters and dev tools via
```pip install lm_eval[dev]```
Then, run
```pre-commit install```
in order to ensure linters and other checks will be run upon committing.
## Testing
We use [pytest](https://docs.pytest.org/en/latest/) for running unit tests. All library unit tests can be run via:
```
python -m pytest --ignore=tests/tests_master --ignore=tests/extra
```
## Contributor License Agreement
We ask that new contributors agree to a Contributor License Agreement affirming that EleutherAI has the rights to use your contribution to our library.
First-time pull requests will have a reply added by @CLAassistant containing instructions for how to confirm this, and we require it before merging your PR.
## Contribution Best Practices
We recommend a few best practices to make your contributions or reported errors easier to assist with.
**For Pull Requests:**
- PRs should be titled descriptively, and be opened with a brief description of the scope and intent of the new contribution.
- New features should have appropriate documentation added alongside them.
- Aim for code maintainability, and minimize code copying.
- If opening a task, try to share test results on the task using a publicly-available model, and if any public results are available on the task, compare to them.
**For Feature Requests:**
- Provide a short paragraph's worth of description. What is the feature you are requesting? What is its motivation, and an example use case of it? How does this differ from what is currently supported?
**For Bug Reports**:
- Provide a short description of the bug.
- Provide a *reproducible example*--what is the command you run with our library that results in this error? Have you tried any other steps to resolve it?
- Provide a *full error traceback* of the error that occurs, if applicable. A one-line error message or small screenshot snippet is unhelpful without the surrounding context.
- Note what version of the codebase you are using, and any specifics of your environment and setup that may be relevant.
**For Requesting New Tasks**:
- Provide a 1-2 sentence description of what the task is and what it evaluates.
- Provide a link to the paper introducing the task.
- Provide a link to where the dataset can be found.
- Provide a link to a paper containing results on an open-source model on the task, for use in comparisons and implementation validation.
- If applicable, link to any codebase that has implemented the task (especially the original publication's codebase, if existent).
## How Can I Get Involved?
To quickly get started, we maintain a list of good first issues, which can be found [on our project board](https://github.com/orgs/EleutherAI/projects/25/views/8) or by [filtering GH Issues](https://github.com/EleutherAI/lm-evaluation-harness/issues?q=is%3Aopen+label%3A%22good+first+issue%22+label%3A%22help+wanted%22). These are typically smaller code changes or self-contained features which can be added without extensive familiarity with library internals, and we recommend new contributors consider taking a stab at one of these first if they are feeling uncertain where to begin.
There are a number of distinct ways to contribute to LM Evaluation Harness, and all are extremely helpful! A sampling of ways to contribute include:
- **Implementing and verifying new evaluation tasks**: Is there a task you'd like to see LM Evaluation Harness support? Consider opening an issue requesting it, or helping add it! Verifying and cross-checking task implementations with their original versions is also a very valuable form of assistance in ensuring standardized evaluation.
- **Improving documentation** - Improvements to the documentation, or noting pain points / gaps in documentation, are helpful in order for us to improve the user experience of the library and clarity + coverage of documentation.
- **Testing and devops** - We are very grateful for any assistance in adding tests for the library that can be run for new PRs, and other devops workflows.
- **Adding new modeling / inference library integrations** - We hope to support a broad range of commonly-used inference libraries popular among the community, and welcome PRs for new integrations, so long as they are documented properly and maintainable.
- **Proposing or Contributing New Features** - We want LM Evaluation Harness to support a broad range of evaluation usecases. If you have a feature that is not currently supported but desired, feel free to open an issue describing the feature and, if applicable, how you intend to implement it. We would be happy to give feedback on the cleanest way to implement new functionalities and are happy to coordinate with interested contributors via GH discussions or via discord.
We hope that this has been helpful, and appreciate your interest in contributing! Further questions can be directed to [our Discord](https://discord.gg/eleutherai).
......@@ -44,6 +44,8 @@ This mode supports a number of command-line arguments, the details of which can
* `--include_path` : Accepts a path to a folder. If passed, then all YAML files containing `lm-eval`-compatible task configurations will be added to the task registry as available tasks. Useful when writing config files for your own task in a folder other than `lm_eval/tasks/`.
* `--predict_only`: Generates the model outputs without computing metrics. Use with `--log_samples` to retrieve decoded results.
## External Library Usage
We also support using the library's external API for use within model training loops or other scripts.
......@@ -59,14 +61,25 @@ import lm_eval
my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`
lm_eval.tasks.initialize_tasks() # register all tasks from the `lm_eval/tasks` subdirectory. Alternatively, can call `lm_eval.tasks.include_path("path/to/my/custom/task/configs")` to only register a set of tasks in a separate directory.
# instantiate an LM subclass that takes your initialized model and can run
# - `Your_LM.loglikelihood()`
# - `Your_LM.loglikelihood_rolling()`
# - `Your_LM.generate_until()`
lm_obj = Your_LM(model=my_model, batch_size=16)
# indexes all tasks from the `lm_eval/tasks` subdirectory.
# Alternatively, you can set `TaskManager(include_path="path/to/my/custom/task/configs")`
# to include a set of tasks in a separate directory.
task_manager = lm_eval.tasks.TaskManager()
# Setting `task_manager` to the one above is optional and should generally be done
# if you want to include tasks from paths other than ones in `lm_eval/tasks`.
# `simple_evaluate` will instantiate its own task_manager if it is set to None here.
results = lm_eval.simple_evaluate( # call simple_evaluate
model=lm_obj,
tasks=["taskname1", "taskname2"],
num_fewshot=0,
task_manager=task_manager,
...
)
```
......@@ -82,18 +95,49 @@ As a brief example usage of `evaluate()`:
```python
import lm_eval
from my_tasks import MyTask1 # suppose you've defined a custom lm_eval.api.Task subclass in your own external codebase
# suppose you've defined a custom lm_eval.api.Task subclass in your own external codebase
from my_tasks import MyTask1
...
my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
# create your model (could be running finetuning with some custom modeling code)
my_model = initialize_my_model()
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`
lm_eval.tasks.initialize_tasks() # register all tasks from the `lm_eval/tasks` subdirectory. Alternatively, can call `lm_eval.tasks.include_path("path/to/my/custom/task/configs")` to only register a set of tasks in a separate directory.
# instantiate an LM subclass that takes your initialized model and can run
# - `Your_LM.loglikelihood()`
# - `Your_LM.loglikelihood_rolling()`
# - `Your_LM.generate_until()`
lm_obj = Your_LM(model=my_model, batch_size=16)
# The task_manager indexes tasks including ones
# specified by the user through `include_path`
task_manager = lm_eval.tasks.TaskManager(
include_path="/path/to/custom/yaml"
)
# To get a task dict for `evaluate`
task_dict = lm_eval.tasks.get_task_dict(
[
"mmlu", # A stock task
"my_custom_task", # A custom task
{
"task": ..., # A dict that configures a task
"doc_to_text": ...,
},
MyTask1 # A task object from `lm_eval.task.Task`
],
task_manager # A task manager that allows lm_eval to
# load the task during evaluation.
# If none is provided, `get_task_dict`
# will instantiate one itself, but this
# only includes the stock tasks so users
# will need to set this if including
# custom paths is required.
)
results = lm_eval.evaluate(
    lm=lm_obj,
    task_dict=task_dict,
    ...
)
```
......@@ -256,7 +256,7 @@ metric_list:
```
`aggregation` and `higher_is_better` can optionally be left out to fall back to the preset defaults when using a natively supported metric; otherwise, they must be defined explicitly (for example, when using a custom metric implemented as a function).
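As an illustration, here is a sketch of a `metric_list` that mixes a natively supported metric (defaults apply) with a custom function-based metric (explicit fields required); the custom function name is hypothetical:

```yaml
metric_list:
  # Natively supported metric: aggregation and higher_is_better
  # fall back to the library defaults.
  - metric: exact_match
  # Custom metric implemented as a function (hypothetical name):
  # both fields must be defined explicitly.
  - metric: !function utils.my_custom_metric
    aggregation: mean
    higher_is_better: true
```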
For a full list of natively supported metrics and aggregation functions see `docs/advanced_task_guide.md`. All metrics supported in [HuggingFace Evaluate](https://github.com/huggingface/evaluate/tree/main/metrics) can also be used, and will be loaded if a given metric name is not one natively supported in `lm-eval` or `hf_evaluate` is set to `true`.
For a full list of natively supported metrics and aggregation functions see [`docs/task_guide.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md). All metrics supported in [HuggingFace Evaluate](https://github.com/huggingface/evaluate/tree/main/metrics) can also be used, and will be loaded if a given metric name is not one natively supported in `lm-eval` or `hf_evaluate` is set to `true`.
### Optional, More Advanced Setup
......@@ -269,7 +269,7 @@ As a heuristic check:
* Do you expect to compute metrics after applying multiple such processing steps on your model outputs?
* Does your task rely on metrics that need a custom implementation?
For more detail on the task system and advanced features, see `docs/advanced_task_guide.md` . If none of the above sound like they apply to your task, it's time to continue onto checking your task performance!
For more detail on the task system and advanced features, see [`docs/task_guide.md`](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md). If none of the above sounds like it applies to your task, it's time to continue on to checking your task performance!
### Task name + groups (registering a task)
......@@ -294,17 +294,80 @@ This will add your task to the `group1` and `group2` groups, enabling people to
If your task is not in the `lm_eval/tasks` folder, you'll need to tell the Eval Harness where to look for YAML files.
You can do this via adding the Python snippet
You can do this via the `--include_path` argument in `__main__.py`. This argument is used to initialize the `TaskManager` object, which you can also use in your own custom scripts.
```python
from lm_eval.tasks import include_task_folder
include_task_folder("/path/to/yaml/parent/folder")
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
```
to the top of any Python file that is run or imported when performing evaluation, such as `__main__.py`.
Passing `--tasks /path/to/yaml/file` is also accepted.
### Advanced Group Configs
You can make a more complete group config while also tailoring parameters for individual tasks.
For example, let's build a config for evaluating MMLU and a few natural language inference tasks. For MMLU, we can write the name of the benchmark as a subtask under `task`. You can configure parameters such as `num_fewshot`. If the task being configured is a group such as `mmlu` or `super_glue`, the parameter set will be applied to all of its subtasks.
```yaml
group: nli_and_mmlu
task:
- group: nli_tasks
task:
- cb
- anli_r1
- rte
- task: mmlu
num_fewshot: 2
```
Note that you can also insert a group config as a task: to make a group of natural language inference tasks, write a group config as you normally would, but place it inside the task list of the main group being built.
### Duplicate Tasks in Group Configs
There may be cases where you want to evaluate how models perform across prompt variations. You can list an existing task (in the example below, `anli_r1`) multiple times with varying `doc_to_text` implementations. To differentiate the variations, we can use `task_alias`; LM-Eval will recognize that there are multiple variations of the same task and differentiate them.
```yaml
group: flan_held_in
group_alias: Flan (Held-In)
task:
# ANLI R1
- group: anli_r1_flan
group_alias: ANLI R1
task:
- task: anli_r1
task_alias: prompt-0
include: _held_in_template_yaml
doc_to_text: "{{premise}}\n\nChoose your answer ..."
...
- task: anli_r1
task_alias: prompt-1
include: _held_in_template_yaml
doc_to_text: "{{premise}}\n\nBased on ..."
...
```
### Configuring python classes
There can be occasions when YAML-based tasks cannot accommodate how a task needs to be handled. LM-Eval supports manually implementing tasks, as was done before `0.4.x`. To register such a task, simply make a YAML file with the name of the task under `task` and the class object under `class`, using the `!function` prefix.
```yaml
task: squadv2
class: !function task.SQuAD2
```
This also applies to building group configurations with subtasks that are python classes.
```yaml
group: scrolls
task:
- task: scrolls_qasper
class: !function task.Qasper
- task: scrolls_quality
class: !function task.QuALITY
- task: scrolls_narrativeqa
class: !function task.NarrativeQA
...
```
## Beautifying Table Display
To avoid conflicts, each task needs to be registered with a unique name. Because of this, slight variations of a task are still counted as distinct tasks and must be named uniquely. This can be done by appending a suffix that refers to the variation, as in MMLU, where the Flan templates are differentiated from the default by the prefix `mmlu_flan_*`. Printing the full task names can easily clutter the results table at the end of an evaluation, especially with a long list of tasks or a benchmark that comprises many tasks. To make the table more legible, you can use `task_alias` and `group_alias` to provide alternative task and group names to be printed.
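For instance, a sketch of how aliases might be applied (the task and alias names below are hypothetical):

```yaml
group: my_benchmark
group_alias: My Benchmark
task:
  # Registered name stays unique; the alias is what gets printed.
  - task: my_benchmark_subtask_prompt_variant_0
    task_alias: prompt-0
  - task: my_benchmark_subtask_prompt_variant_1
    task_alias: prompt-1
```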
......
......@@ -50,7 +50,7 @@ Scoring details:
- **doc_to_decontamination_query** (`str`, *optional*) — Query for decontamination if `should_decontaminate` is True. If `should_decontaminate` is True but `doc_to_decontamination_query` is `None`, `doc_to_decontamination_query` will follow `doc_to_text`.
Other:
- **metadata** (`Union[str, list]`, *optional*) — An optional field where arbitrary metadata can be passed. A good example would be `version` that is used to denote the version of the yaml config.
- **metadata** (`dict`, *optional*) — An optional field where arbitrary metadata can be passed. Most tasks should include a `version` key in this field, used to denote the version of the yaml config. Another special metadata key is `num_fewshot`, which overrides the printed `n-shot` table column for a task.
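For example, a minimal sketch of a `metadata` block using these keys:

```yaml
metadata:
  version: 1.0
  # Optional: overrides the n-shot value printed in the results table.
  num_fewshot: 5
```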
## Filters
......
......@@ -10,7 +10,7 @@ from typing import Union
import numpy as np
from lm_eval import evaluator, utils
from lm_eval.tasks import TaskManager
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table
......@@ -142,6 +142,13 @@ def parse_eval_args() -> argparse.Namespace:
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
)
parser.add_argument(
"--predict_only",
"-x",
action="store_true",
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
return parser.parse_args()
......@@ -155,7 +162,12 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Verbosity set to {args.verbosity}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# initialize_tasks(args.verbosity)
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
assert args.output_path, "Specify --output_path"
initialize_tasks(args.verbosity)
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
if args.limit:
......@@ -169,7 +181,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
sys.exit()
elif args.tasks == "list":
eval_logger.info(
"Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks()))
"Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
)
else:
if os.path.isdir(args.tasks):
......@@ -181,20 +193,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
config = utils.load_yaml_config(yaml_file)
loaded_task_list.append(config)
else:
input_task_list = args.tasks.split(",")
loaded_task_list = utils.pattern_match(
input_task_list, task_manager.all_tasks()
)
for task in [
task for task in input_task_list if task not in loaded_task_list
]:
task_list = args.tasks.split(",")
task_names = task_manager.match_tasks(task_list)
for task in [task for task in task_list if task not in task_names]:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
loaded_task_list.append(config)
task_missing = [
task
for task in input_task_list
if task not in loaded_task_list and "*" not in task
task for task in task_list if task not in task_names and "*" not in task
] # we don't want errors if a wildcard ("*") task name was used
if task_missing:
......@@ -224,14 +230,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
else:
path.mkdir(parents=True, exist_ok=True)
output_path_file = path.joinpath("results.json")
elif args.log_samples and not args.output_path:
assert args.output_path, "Specify --output_path"
eval_logger.info(f"Selected Tasks: {loaded_task_list}")
eval_logger.info(f"Selected Tasks: {task_names}")
eval_logger.info("Loading selected tasks...")
all_tasks = task_manager.load_task_or_group(loaded_task_list)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
......@@ -247,6 +249,8 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
write_out=args.write_out,
log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs,
task_manager=task_manager,
predict_only=args.predict_only,
)
if results is not None:
......@@ -261,7 +265,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
if args.output_path:
output_path_file.open("w").write(dumped)
output_path_file.open("w", encoding="utf-8").write(dumped)
if args.log_samples:
for task_name, config in results["configs"].items():
......
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
from datasets import Dataset
from typing import Callable, Iterable, List, Union
from lm_eval.api.instance import Instance
class Filter:
class Filter(ABC):
"""
Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`)
......@@ -15,12 +14,13 @@ class Filter:
"""
def __init__(self, *args, **kwargs) -> None:
def __init__(self, **kwargs) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps, docs):
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
......@@ -40,15 +40,15 @@ class FilterEnsemble:
"""
name: str
filters: List[Filter]
filters: List[Callable[[], Filter]]
def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)
def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
resps = [
inst.resps for inst in instances
] # operate just on the model responses
for f in self.filters:
# apply filters in sequence
resps = f.apply(resps, docs)
resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
......@@ -4,7 +4,12 @@ from typing import Literal, Tuple
@dataclass
class Instance:
request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
request_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
]
doc: dict
arguments: tuple
idx: int
......
......@@ -16,6 +16,11 @@ eval_logger = logging.getLogger("lm-eval")
# Register Aggregations First
@register_aggregation("bypass")
def bypass_agg(arr):
return 999
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
......@@ -241,6 +246,16 @@ def mean_stderr(arr):
return sample_stddev(arr) / math.sqrt(len(arr))
@register_metric(
metric="bypass",
higher_is_better=True,
output_type=["loglikelihood", "multiple_choice", "generate_until"],
aggregation="bypass",
)
def bypass(items):
return None
@register_metric(
metric="mcc",
higher_is_better=True,
......
......@@ -152,18 +152,14 @@ def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
"{} not a registered aggregation metric!".format(name),
)
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name):
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(
"{} metric is not assigned a default aggregation!".format(name),
)
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name):
......
......@@ -5,6 +5,7 @@ import random
import re
from collections.abc import Callable
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, List, Literal, Tuple, Union
import datasets
......@@ -37,7 +38,6 @@ ALL_OUTPUT_TYPES = [
"generate_until",
]
eval_logger = logging.getLogger("lm-eval")
......@@ -90,16 +90,25 @@ class TaskConfig(dict):
num_fewshot: int = None
# scoring options
metric_list: list = None
output_type: str = "generate_until"
output_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
] = "generate_until"
generation_kwargs: dict = None
repeats: int = 1
filter_list: Union[str, list] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
weight_by_size: bool = False
metadata: dict = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
......@@ -126,17 +135,20 @@ class TaskConfig(dict):
"do_sample": False,
}
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
# if self.dataset_kwargs is None:
#     self.dataset_kwargs = {"trust_remote_code": True}
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable=False):
def to_dict(self, keep_callable: bool = False) -> dict:
"""Dumps the current config as a dictionary object, in a printable format.
Null fields will not be printed.
Used for dumping results alongside full task configuration
......@@ -151,14 +163,34 @@ class TaskConfig(dict):
for k, v in list(cfg_dict.items()):
if v is None:
cfg_dict.pop(k)
elif isinstance(v, Callable):
if keep_callable:
cfg_dict[k] = v
else:
# TODO: this should handle Promptsource template objects as a separate case?
cfg_dict[k] = str(v)
elif k == "metric_list":
for metric_dict in v:
for metric_key, metric_value in metric_dict.items():
if callable(metric_value):
metric_dict[metric_key] = self.serialize_function(
metric_value, keep_callable=keep_callable
)
cfg_dict[k] = v
elif callable(v):
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
return cfg_dict
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
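The `getsource` fallback above can be exercised standalone: builtins have no retrievable Python source, so they raise `TypeError` and fall back to `str(value)` (a minimal sketch of the same logic):

```python
# Sketch of the serialize_function fallback: inspect.getsource returns the
# source text for file-defined functions, but raises TypeError for builtins
# (and OSError when source is unavailable), in which case we use str(value).
from inspect import getsource

def serialize_function(value, keep_callable=False):
    if keep_callable:
        return value
    try:
        return getsource(value)
    except (TypeError, OSError):
        return str(value)

print(serialize_function(len))  # <built-in function len>
```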
class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems,
......@@ -303,7 +335,7 @@ class Task(abc.ABC):
return self.validation_docs()
else:
eval_logger.warning(
"has_training_docs and has_validation_docs are False"
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
)
return self.test_docs()
......@@ -434,6 +466,9 @@ class Task(abc.ABC):
"""
pass
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
@classmethod
def count_bytes(cls, doc):
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
......@@ -508,7 +543,7 @@ class Task(abc.ABC):
def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances, None)
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
......@@ -619,7 +654,7 @@ class ConfigurableTask(Task):
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
......@@ -631,7 +666,7 @@ class ConfigurableTask(Task):
]
else:
eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
......@@ -644,16 +679,15 @@ class ConfigurableTask(Task):
if self.config.filter_list is not None:
self._filters = []
for filter_config in self.config.filter_list:
for filter_pipeline in filter_config:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
......@@ -831,7 +865,7 @@ class ConfigurableTask(Task):
def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances, self.task_docs)
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
......@@ -1226,12 +1260,46 @@ class ConfigurableTask(Task):
return result_dict
def aggregation(self):
def aggregation(self) -> dict:
return self._aggregation_list
def higher_is_better(self):
def higher_is_better(self) -> dict:
return self._higher_is_better
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
def override_metric(self, metric_name: str) -> None:
"""
Override the default metrics used for evaluation with custom metrics.
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
def override_config(
self, key: str = None, value: Any = None, update: bool = False
) -> None:
if update:
current_value = getattr(self._config, key)
assert isinstance(current_value, dict)
current_value.update(value)
setattr(self._config, key, current_value)
else:
setattr(self._config, key, value)
class MultipleChoiceTask(Task):
OUTPUT_TYPE: str = "loglikelihood"
......
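The new `override_config(key, value, update=True)` path merges dict-valued keys in place rather than replacing them; a minimal standalone sketch of the semantics, using a hypothetical plain config object instead of the real `TaskConfig`:

```python
# Hypothetical stand-in for TaskConfig, only to illustrate the difference
# between update=True (merge into an existing dict) and a plain overwrite.
class Cfg:
    def __init__(self):
        self.generation_kwargs = {"until": ["\n"], "do_sample": False}
        self.num_fewshot = 0

def override_config(cfg, key, value, update=False):
    if update:
        current_value = getattr(cfg, key)
        assert isinstance(current_value, dict)
        current_value.update(value)   # merge into the existing dict
        setattr(cfg, key, current_value)
    else:
        setattr(cfg, key, value)      # replace wholesale

cfg = Cfg()
override_config(cfg, "generation_kwargs", {"do_sample": True}, update=True)
override_config(cfg, "num_fewshot", 5)
print(cfg.generation_kwargs)  # {'until': ['\n'], 'do_sample': True}
print(cfg.num_fewshot)        # 5
```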
......@@ -30,7 +30,9 @@ class Archive:
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta={}) -> None:
def add_data(self, data, meta=None) -> None:
if meta is None:
meta = {}
self.compressor.write(
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
"UTF-8"
......@@ -108,7 +110,7 @@ class TextReader:
def read_tqdm(self, update_frequency: int = 10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, "r") as fh, tqdm.tqdm(
with open(self.file_path, "r", encoding="utf-8") as fh, tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
unit="byte",
......
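The `meta={}` to `meta=None` change in `add_data` avoids Python's shared-mutable-default pitfall; a sketch of the failure mode it prevents (toy functions, not the real `Archive` method):

```python
# A default dict is created once at function definition time and shared
# across every call that omits the argument, so state leaks between calls.
def add_data_buggy(data, meta={}):
    meta.setdefault("seen", []).append(data)
    return meta

def add_data_fixed(data, meta=None):
    if meta is None:
        meta = {}  # fresh dict per call
    meta.setdefault("seen", []).append(data)
    return meta

print(add_data_buggy("a")["seen"])  # ['a']
print(add_data_buggy("b")["seen"])  # ['a', 'b']  <- leaked from the first call
print(add_data_fixed("a")["seen"])  # ['a']
print(add_data_fixed("b")["seen"])  # ['b']
```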
......@@ -38,7 +38,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path = os.path.join(ngrams_path, "info.json")
info_dict = json.load(open(info_dict_path, "r"))
info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
ngrams_n_size = info_dict["ngram_size"]
janitor = Janitor()
......
......@@ -4,20 +4,24 @@ import collections
import torch
import logging
import numpy as np
import lm_eval.api
import lm_eval.tasks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry
from lm_eval.tasks import (
get_task_dict,
TaskManager
)
from lm_eval.utils import (
positional_deprecated,
run_task_tests,
get_git_commit_hash,
simple_parse_args_string,
eval_logger,
eval_logger
)
......@@ -25,7 +29,7 @@ from lm_eval.utils import (
def simple_evaluate(
model,
model_args=None,
tasks=[],
tasks=None,
num_fewshot=None,
batch_size=None,
max_batch_size=None,
......@@ -38,7 +42,9 @@ def simple_evaluate(
write_out: bool = False,
log_samples: bool = True,
gen_kwargs: str = None,
weight_by_size: bool = False,
task_manager: TaskManager = None,
verbosity: str = "INFO",
predict_only: bool = False,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -47,8 +53,8 @@ def simple_evaluate(
:param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object.
:param tasks: list[Task]
List of Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param tasks: list[Union[str, dict, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param batch_size: int or str, optional
......@@ -72,6 +78,9 @@ def simple_evaluate(
:param gen_kwargs: str
String arguments for model generation
Ignored for all tasks with loglikelihood output_type
:param predict_only: bool
If true only model outputs will be generated and returned. Metrics will not be evaluated
:return
Dictionary of results
"""
......@@ -81,6 +90,10 @@ def simple_evaluate(
1234
) # TODO: this may affect training runs that are run with evaluation mid-run.
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
if tasks is None:
tasks = []
assert (
tasks != []
), "No tasks specified, or no tasks found. Please verify the task names."
......@@ -88,7 +101,7 @@ def simple_evaluate(
if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!"
)
if gen_kwargs == "":
gen_kwargs = None
......@@ -120,30 +133,45 @@ def simple_evaluate(
+ ".db",
)
task_dict = tasks
if task_manager is None:
task_manager = TaskManager(verbosity)
eval_logger.info(
"get_task_dict has been updated to accept an optional argument, `task_manager`. "
"Read more here: https://github.com/EleutherAI/lm-evaluation-harness/blob/recursive-groups/docs/interface.md#external-library-usage"
)
task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if task_obj is None:
continue
config = task_obj._config
if config["output_type"] == "generate_until" and gen_kwargs is not None:
config["generation_kwargs"].update(gen_kwargs)
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.override_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
if predict_only:
log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
# we have to change the class properties post-hoc. This is pretty hacky.
task_obj.override_metric(metric_name="bypass")
if num_fewshot is not None:
if config["num_fewshot"] == 0:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
default_num_fewshot = config["num_fewshot"]
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj._config["num_fewshot"] = num_fewshot
task_obj.override_config(key="num_fewshot", value=num_fewshot)
if check_integrity:
run_task_tests(task_list=tasks)
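Several hunks in this file swap `type(task) == tuple` for `isinstance(task, tuple)`. Besides being idiomatic, `isinstance` also accepts tuple subclasses such as namedtuples (illustrative sketch; the `(group, task)` pair here is hypothetical):

```python
from collections import namedtuple

# A (group, task) pair expressed as a namedtuple rather than a bare tuple.
Entry = namedtuple("Entry", ["group", "task"])
entry = Entry("mmlu", "abstract_algebra")

print(type(entry) is tuple)      # False: exact-type check misses subclasses
print(isinstance(entry, tuple))  # True
group, task = entry              # unpacking works either way
print(group)  # mmlu
```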
......@@ -156,7 +184,7 @@ def simple_evaluate(
decontamination_ngrams_path=decontamination_ngrams_path,
write_out=write_out,
log_samples=log_samples,
weight_by_size=weight_by_size,
verbosity=verbosity,
)
if lm.rank == 0:
......@@ -199,7 +227,7 @@ def evaluate(
decontamination_ngrams_path=None,
write_out: bool = False,
log_samples: bool = True,
weight_by_size: bool = False,
verbosity: str = "INFO",
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -219,8 +247,17 @@ def evaluate(
Dictionary of results
"""
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
# decontaminate = decontamination_ngrams_path is not None
for task_name, task in task_dict.items():
if isinstance(task, tuple):
_, task = task
if not log_samples:
assert (
"bypass" not in getattr(task, "_metric_fn_list", {}).keys()
), f"log_samples must be True for 'bypass' only tasks: {task_name}"
# stores the final result for each task, for each metric/filter pair.
results = collections.defaultdict(dict)
# Tracks each task's version.
......@@ -245,7 +282,7 @@ def evaluate(
# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
if isinstance(task, tuple):
group_name, task = task
task_hierarchy[group_name].append(task_name)
versions[group_name] = "N/A"
......@@ -261,9 +298,12 @@ def evaluate(
configs[task_name] = dict(task.dump_config())
if "num_fewshot" in configs[task_name]:
n_shot = configs[task_name]["num_fewshot"]
if configs[task_name]["metadata"]:
n_shot = configs[task_name]["metadata"].get("num_fewshot", None)
if not n_shot:
n_shot = configs[task_name]["num_fewshot"]
else:
n_shot = 0
n_shot = 0 # TODO: is this always right?
num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]:
......@@ -319,7 +359,7 @@ def evaluate(
### Run LM on inputs, get all outputs ###
# execute each type of request
for reqtype, reqs in requests.items():
eval_logger.info("Running {} requests".format(reqtype))
eval_logger.info(f"Running {reqtype} requests")
# create `K` copies of each request `req` based off `K = req.repeats`
cloned_reqs = []
for req in reqs:
......@@ -342,7 +382,7 @@ def evaluate(
### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_name, task in task_dict.items():
if type(task) == tuple:
if isinstance(task, tuple):
group, task = task
if task is None:
continue
......@@ -353,7 +393,7 @@ def evaluate(
# unpack results and sort back in order and return control to Task
for task_name, task in task_dict.items():
if type(task) == tuple:
if isinstance(task, tuple):
group, task = task
if task is None:
continue
......@@ -404,7 +444,7 @@ def evaluate(
vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items():
numitem = 0
if type(items[0]) == tuple:
if isinstance(items[0], tuple):
numitem = len(items[0])
if isinstance(items[0], (str, list, tuple)):
......@@ -450,7 +490,7 @@ def evaluate(
task = task_dict[task_name]
metric_key = metric + "," + key
if type(task) == tuple:
if isinstance(task, tuple):
group_name, task = task
else:
group_name = None
......@@ -477,7 +517,8 @@ def evaluate(
if bool(results):
for group, task_list in reversed(task_hierarchy.items()):
if task_list == []:
total_size = results[group]["samples"]
# TODO: No samples when bypass
total_size = results[group].get("samples", 999)
else:
total_size = 0
......@@ -487,16 +528,7 @@ def evaluate(
if "alias" in metrics:
metrics.pop("alias")
if ("weight_by_size" in configs) and configs[task]["weight_by_size"]:
current_size = metrics.pop("samples")
else:
metrics.pop("samples")
current_size = 1
# TODO: Tasks like brier score for individual
# tasks have no stderr since the score is
# itself an aggregation. But it's possible to
# calculate the stderr over groups
current_size = metrics.pop("samples")
all_stderr = []
for metric in [
......@@ -518,7 +550,7 @@ def evaluate(
+ metric_score * current_size
) / (total_size + current_size)
# $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
if var_score == "N/A":
if var_score == "N/A" or results[group][stderr] == "N/A":
results[group][stderr] = "N/A"
else:
results[group][stderr] = (
......@@ -617,7 +649,7 @@ def evaluate(
for group_name, task_list in task_hierarchy.items():
if task_list != []:
num_fewshot[group_name] = num_fewshot[task_list[0]]
num_fewshot[group_name] = num_fewshot[task_list[0]] # TODO: validate this
results_dict = {
"results": dict(results_agg.items()),
......
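The group-aggregation hunk above weights each task's score by its sample count and updates the group stderr via the combined-sample-variance identity quoted in the code comment; a numeric check (with toy numbers, not real scores) that the formula matches the variance of the concatenated samples:

```python
# Verify the pooled-variance identity from the code comment:
# s_z^2 = ((n-1)s_x^2 + (m-1)s_y^2)/(n+m-1) + n*m*(xbar-ybar)^2/((n+m)(n+m-1))
import statistics

x = [1.0, 2.0, 3.0, 4.0]   # per-sample scores from task x
y = [2.0, 4.0, 6.0]        # per-sample scores from task y
n, m = len(x), len(y)
sx2, sy2 = statistics.variance(x), statistics.variance(y)
xbar, ybar = statistics.mean(x), statistics.mean(y)

pooled = ((n - 1) * sx2 + (m - 1) * sy2) / (n + m - 1) \
    + n * m * (xbar - ybar) ** 2 / ((n + m) * (n + m - 1))

print(round(pooled, 6))                      # 2.809524
print(round(statistics.variance(x + y), 6))  # 2.809524, same value
```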
from typing import List, Union
from functools import partial
from lm_eval.api.filter import FilterEnsemble
from . import selection
from . import extraction
......@@ -20,24 +23,25 @@ FILTER_REGISTRY = {
}
def get_filter(filter_name):
def get_filter(filter_name: str) -> Union[type, str]:
if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name]
else:
return filter_name
def build_filter_ensemble(filter_name, components):
def build_filter_ensemble(
filter_name: str, components: List[List[str]]
) -> FilterEnsemble:
"""
Create a filtering pipeline.
"""
filters = []
for function, kwargs in components:
if kwargs is None:
f = get_filter(function)()
else:
# create a filter given its name in the registry
f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
kwargs = {}
# create a filter given its name in the registry
f = partial(get_filter(function), **kwargs)
# add the filter as a pipeline step
filters.append(f)
......
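`build_filter_ensemble` now stores `functools.partial` objects: each filter's kwargs are bound immediately, but instantiation is deferred until `FilterEnsemble.apply` calls `f()`, so every apply gets a fresh filter instance. A minimal sketch with a hypothetical filter class:

```python
from functools import partial

class TakeK:  # hypothetical stand-in for a registered filter class
    def __init__(self, k):
        self.k = k

    def apply(self, resps, docs):
        # keep only the first k responses per document
        return [r[: self.k] for r in resps]

# kwargs are bound now; a fresh filter instance is built per apply() call
f = partial(TakeK, k=2)
out = f().apply([["a", "b", "c"], ["d", "e"]], docs=None)
print(out)  # [['a', 'b'], ['d', 'e']]
```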
......@@ -17,12 +17,14 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
def apply(self, resps, docs):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
assert (
len(resps[0]) >= self.k
......
......@@ -24,7 +24,7 @@ class UppercaseFilter(Filter):
class MapFilter(Filter):
def __init__(self, mapping_dict: dict = {}, default_value=None) -> None:
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
"""
Initializes the MapFilter with a given mapping dictionary and default value.
......@@ -37,6 +37,8 @@ class MapFilter(Filter):
Example:
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
"""
if mapping_dict is None:
mapping_dict = {}
assert isinstance(
mapping_dict, dict
), "Provided mapping_dict is not a dictionary"
......
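The docstring's `MapFilter({'A': 1, 'B': 2}, default_value=0)` example, fleshed out as a self-contained sketch; the `apply` body here is an assumption based on the class's description, not the library's exact implementation:

```python
class MapFilter:
    def __init__(self, mapping_dict=None, default_value=None):
        if mapping_dict is None:
            mapping_dict = {}  # avoid a shared mutable default
        assert isinstance(
            mapping_dict, dict
        ), "Provided mapping_dict is not a dictionary"
        self.mapping_dict = mapping_dict
        self.default_value = default_value

    def apply(self, resps, docs):
        # map every response; unseen values fall back to default_value
        return [
            [self.mapping_dict.get(r, self.default_value) for r in rr]
            for rr in resps
        ]

mapper = MapFilter({"A": 1, "B": 2}, default_value=0)
print(mapper.apply([["A", "C"], ["B"]], docs=None))  # [[1, 0], [2]]
```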
......@@ -6,5 +6,6 @@ from . import anthropic_llms
from . import gguf
from . import vllm_causallms
from . import mamba_lm
from . import optimum_lm
# TODO: implement __all__