Commit 0d1ef037 authored by lintangsutawika

solved merge conflict

parents aa44be3f ada4a31d
...@@ -56,7 +56,7 @@ jobs: ...@@ -56,7 +56,7 @@ jobs:
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true' if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies # Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
......
...@@ -17,29 +17,22 @@ jobs: ...@@ -17,29 +17,22 @@ jobs:
linter: linter:
name: Linters name: Linters
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 20 timeout-minutes: 5
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up Python 3.8 - name: Set up Python 3.8
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: 3.8 python-version: 3.8
cache: pip cache: pip
cache-dependency-path: setup.py cache-dependency-path: pyproject.toml
- name: Install dependencies
run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu ; export SKIP=no-commit-to-branch # env var deactivates --no-commit-to-branch
- name: Pre-Commit - name: Pre-Commit
env:
SKIP: "no-commit-to-branch,mypy"
uses: pre-commit/action@v3.0.0 uses: pre-commit/action@v3.0.0
- name: Lint with pylint
run: python -m pylint --disable=all -e W0311 --jobs=0 --indent-string=' ' **/*.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# # mypy turned off for now # # mypy turned off for now
# - name: Lint with mypy # - name: Lint with mypy
# run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable # run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
...@@ -53,22 +46,22 @@ jobs: ...@@ -53,22 +46,22 @@ jobs:
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
cache: pip cache: pip
cache-dependency-path: setup.py cache-dependency-path: pyproject.toml
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu pip install -e '.[dev,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies # Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest - name: Test with pytest
run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra run: python -m pytest --showlocals -s -vv -n=auto
- name: Archive artifacts - name: Archive artifacts
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
......
...@@ -17,6 +17,7 @@ repos: ...@@ -17,6 +17,7 @@ repos:
- id: detect-private-key - id: detect-private-key
- id: end-of-file-fixer - id: end-of-file-fixer
- id: no-commit-to-branch - id: no-commit-to-branch
always_run: false
- id: requirements-txt-fixer - id: requirements-txt-fixer
- id: trailing-whitespace - id: trailing-whitespace
args: [--markdown-linebreak-ext=md] args: [--markdown-linebreak-ext=md]
...@@ -26,14 +27,16 @@ repos: ...@@ -26,14 +27,16 @@ repos:
args: [--remove] args: [--remove]
- id: mixed-line-ending - id: mixed-line-ending
args: [--fix=lf] args: [--fix=lf]
- repo: https://github.com/pycqa/flake8 - repo: https://github.com/astral-sh/ruff-pre-commit
rev: 3.7.9 # Ruff version.
rev: v0.1.8
hooks: hooks:
- id: flake8 # Run the linter.
- repo: https://github.com/psf/black - id: ruff
rev: 22.3.0 args:
hooks: - --fix
- id: black # Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.1.0 rev: v2.1.0
hooks: hooks:
......
@software{eval-harness, @misc{eval-harness,
author = {Gao, Leo and Tow, Jonathan and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and McDonell, Kyle and Muennighoff, Niklas and Phang, Jason and Reynolds, Laria and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy}, author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
title = {A framework for few-shot language model evaluation}, title = {A framework for few-shot language model evaluation},
month = sep, month = 12,
year = 2021, year = 2023,
publisher = {Zenodo}, publisher = {Zenodo},
version = {v0.0.1}, version = {v0.4.0},
doi = {10.5281/zenodo.5371628}, doi = {10.5281/zenodo.10256836},
url = {https://doi.org/10.5281/zenodo.5371628} url = {https://zenodo.org/records/10256836}
} }
* @haileyschoelkopf @lintangsutawika @StellaAthena * @haileyschoelkopf @lintangsutawika
...@@ -4,7 +4,7 @@ Welcome to the docs for the LM Evaluation Harness! ...@@ -4,7 +4,7 @@ Welcome to the docs for the LM Evaluation Harness!
## Table of Contents ## Table of Contents
* To learn about the public interface of the library, as well as how to evaluate via the commandline or as integrated into an external library, see the [Interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/user_guide.md) * To learn about the public interface of the library, as well as how to evaluate via the commandline or as integrated into an external library, see the [Interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/interface.md)
* To learn how to add a new library, API, or model type to the library, as well as a quick explainer on the types of ways to evaluate an LM, see the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/model_guide.md). * To learn how to add a new library, API, or model type to the library, as well as a quick explainer on the types of ways to evaluate an LM, see the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/model_guide.md).
* For a crash course on adding new tasks to the library, see our [New Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md). * For a crash course on adding new tasks to the library, see our [New Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md).
* To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Task Configuration Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/task_guide.md). * To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Task Configuration Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/task_guide.md).
...@@ -46,16 +46,6 @@ dataset_name: ... # the dataset configuration to use. Leave `null` if your datas ...@@ -46,16 +46,6 @@ dataset_name: ... # the dataset configuration to use. Leave `null` if your datas
dataset_kwargs: null # any extra keyword arguments that should be passed to the dataset constructor, e.g. `data_dir`. dataset_kwargs: null # any extra keyword arguments that should be passed to the dataset constructor, e.g. `data_dir`.
``` ```
------------------------------
**Tip:** To load a local dataset for evaluation, you can specify data files in the `dataset_kwargs` field, such as the following for JSON files:
```
dataset_path: json
dataset_name: null
dataset_kwargs:
data_files: /path/to/my/json
```
-------------------------------
Next, we'd like to tell our task what the dataset's train, validation, and test splits are named, if they exist: Next, we'd like to tell our task what the dataset's train, validation, and test splits are named, if they exist:
```yaml ```yaml
...@@ -99,6 +89,36 @@ Now, in our YAML config file we'll use the `!function` constructor, and tell the ...@@ -99,6 +89,36 @@ Now, in our YAML config file we'll use the `!function` constructor, and tell the
process_docs: !function utils.process_docs process_docs: !function utils.process_docs
``` ```
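As a reference point, `process_docs` is simply a Python function that takes a `datasets.Dataset` and returns the processed `datasets.Dataset`. A minimal sketch of such a helper in `utils.py` might look like the following (the column names are hypothetical placeholders for whatever fields your dataset provides):
```python
# utils.py -- minimal sketch of a process_docs helper.
# Column names ("question", "choices", "label") are placeholders.
import datasets


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc):
        # Normalize each document so the prompt template has stable fields.
        return {
            "query": doc["question"].strip(),
            "choices": doc["choices"],
            "gold": int(doc["label"]),
        }

    return dataset.map(_process_doc)
```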
### Using Local Datasets
To load a local dataset for evaluation, you can specify data files in the `dataset_kwargs` field, such as the following for JSON files:
```
dataset_path: json
dataset_name: null
dataset_kwargs:
data_files: /path/to/my/json
```
Or with files already split into separate directories:
```
dataset_path: arrow
dataset_kwargs:
data_files:
train: /path/to/arrow/train/data-00000-of-00001.arrow
validation: /path/to/arrow/validation/data-00000-of-00001.arrow
```
Alternatively, if you have previously downloaded a dataset from the Hugging Face Hub (using `save_to_disk()`) and wish to use those local files, you will need to set `data_dir` under `dataset_kwargs` to point to the saved directory.
```
dataset_path: hellaswag
dataset_kwargs:
data_dir: hellaswag_local/
```
You can also set `dataset_path` to a directory path on your local filesystem. This assumes that the directory contains a loading script with the same name as the directory. [See the datasets docs](https://huggingface.co/docs/datasets/loading#local-loading-script).
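Under the hood, these fields are essentially forwarded to Hugging Face `datasets.load_dataset`, so a rough Python equivalent of the JSON example above is the following sketch (paths are placeholders, shown only to illustrate how the YAML fields map onto keyword arguments):
```python
# Roughly what the harness does with the YAML fields above (illustrative only).
from datasets import load_dataset

dataset = load_dataset(
    path="json",                    # dataset_path
    name=None,                      # dataset_name
    data_files="/path/to/my/json",  # dataset_kwargs -> data_files
)
```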
## Writing a Prompt Template ## Writing a Prompt Template
The next thing we need to do is decide what format to use when presenting the data to the LM. This is our **prompt**, where we'll define both an input and output format. The next thing we need to do is decide what format to use when presenting the data to the LM. This is our **prompt**, where we'll define both an input and output format.
...@@ -315,6 +335,25 @@ python -m scripts.write_out \ ...@@ -315,6 +335,25 @@ python -m scripts.write_out \
Open the file specified at the `--output_base_path <path>` and ensure it passes Open the file specified at the `--output_base_path <path>` and ensure it passes
a simple eye test. a simple eye test.
## Versioning
One key feature in the LM Evaluation Harness is the ability to version tasks--that is, to mark them with a specific version number that can be bumped whenever a breaking change is made.
This version info can be provided by adding the following to your new task config file:
```
metadata:
version: 0
```
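For reference, this commit also wires the field through on the code side; the snippet below mirrors the change to `ConfigurableTask.__init__` that appears later in this diff, where the task's reported version is read from the YAML metadata block.
```python
# Mirrors the ConfigurableTask.__init__ change in this commit:
# the task's VERSION is taken from the YAML metadata block, when present.
if isinstance(self.config.metadata, dict):
    if "version" in self.config.metadata:
        self.VERSION = self.config.metadata["version"]
```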
Now, whenever a change needs to be made to your task in the future, please increase the version number by 1 so that users can distinguish between the different task iterations.
If you are incrementing a task's version, please also consider adding a changelog entry to the task's README.md noting the date, the PR number, the version you have updated to, and a one-liner describing the change.
For example:
* \[Dec 25, 2023\] (PR #999) Version 0.0 -> 1.0: Fixed a bug with answer extraction that led to underestimated performance.
## Checking performance + equivalence ## Checking performance + equivalence
It's now time to check models' performance on your task! In the evaluation harness, we intend to support a wide range of evaluation tasks and setups, but prioritize the inclusion of already-proven benchmarks following the precise evaluation setups in the literature where possible. It's now time to check models' performance on your task! In the evaluation harness, we intend to support a wide range of evaluation tasks and setups, but prioritize the inclusion of already-proven benchmarks following the precise evaluation setups in the literature where possible.
...@@ -340,4 +379,4 @@ It is recommended to include a filled-out copy of this checklist in the README.m ...@@ -340,4 +379,4 @@ It is recommended to include a filled-out copy of this checklist in the README.m
## Submitting your task ## Submitting your task
You're all set! Now push your work and make a pull request to the `big-refactor` branch! Thanks for the contribution :). If there are any questions, please leave a message in the `#lm-thunderdome` channel on the EAI discord! You're all set! Now push your work and make a pull request to the `main` branch! Thanks for the contribution :). If there are any questions, please leave a message in the `#lm-thunderdome` channel on the EAI discord!
...@@ -219,6 +219,49 @@ Aggregation functions: ...@@ -219,6 +219,49 @@ Aggregation functions:
* `weighted_perplexity` * `weighted_perplexity`
* `bits_per_byte` * `bits_per_byte`
### Adding a Multiple Choice Metric
Adding a multiple choice metric has a few steps. To get it working you need to:
1. register a metric function
2. register an aggregation function
3. update the `Task` definition to make sure the correct arguments are passed
The default metric and aggregation functions are in `lm_eval/api/metrics.py`, and you can add a function there if it's for general use. The metrics are towards the bottom of the file and look like this:

    @register_metric(
        metric="mcc",
        higher_is_better=True,
        output_type="multiple_choice",
        aggregation="matthews_corrcoef",
    )
    def mcc_fn(items):  # This is a passthrough function
        return items
Note that many of these are passthrough functions, and for multiple choice (at least) this function is never actually called.
Aggregation functions are defined towards the top of the file; here's an example:

    @register_aggregation("matthews_corrcoef")
    def matthews_corrcoef(items):
        unzipped_list = list(zip(*items))
        golds = unzipped_list[0]
        preds = unzipped_list[1]
        return sklearn.metrics.matthews_corrcoef(golds, preds)
This function returns a single numeric value. The input is defined in `Task.process_results` in `lm_eval/api/task.py`. There's a section that looks like this:

    result_dict = {
        **({"acc": acc} if "acc" in use_metric else {}),
        **({"f1": (gold, pred)} if "f1" in use_metric else {}),
        **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
        **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
        **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
    }
The value here determines the input to the aggregation function, while the key matches the name of the metric function. The metrics above have simple needs, requiring only the accuracy or the gold and predicted values; immediately below this section there are examples of metrics with more complicated requirements that you can use as a reference.
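Putting the three steps together, here is a minimal sketch for a hypothetical `balanced_acc` metric. It is not part of the harness; the decorators and registry imports are the ones used in the examples above, and the metric name is made up for illustration.
```python
# Sketch: registering a hypothetical "balanced_acc" multiple-choice metric.
import sklearn.metrics

from lm_eval.api.registry import register_aggregation, register_metric


# 1. Aggregation: consumes the (gold, pred) pairs collected per document.
@register_aggregation("balanced_acc")
def balanced_acc(items):
    golds, preds = zip(*items)
    return sklearn.metrics.balanced_accuracy_score(golds, preds)


# 2. Metric: a passthrough, like mcc above; the aggregation does the work.
@register_metric(
    metric="balanced_acc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="balanced_acc",
)
def balanced_acc_fn(items):
    return items


# 3. Task.process_results would also need an entry such as
#    **({"balanced_acc": (gold, pred)} if "balanced_acc" in use_metric else {}),
#    so that the (gold, pred) pairs reach the aggregation.
```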
## Good Reference Tasks ## Good Reference Tasks
...@@ -258,6 +301,23 @@ task: ...@@ -258,6 +301,23 @@ task:
- hendrycksTest* - hendrycksTest*
``` ```
It is also possible to list an existing task in your benchmark configuration with some adjustments. For example, a few tasks from MMLU are included in `multimedqa`. There, the `task_alias` and `group_alias` (see [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md#beautifying-table-display) for more details) are modified to suit the benchmark.
```yaml
group: multimedqa
task:
- pubmedqa
- medmcqa
- medqa_4options
- task: mmlu_anatomy
task_alias: "anatomy (mmlu)"
group_alias: null
- task: mmlu_clinical_knowledge
task_alias: "clinical_knowledge (mmlu)"
group_alias: null
...
```
Alternatively, benchmarks can have tasks that are customizable for each task. They can be defined like how a yaml task is usually set. Alternatively, benchmarks can have tasks that are customizable for each task. They can be defined like how a yaml task is usually set.
```yaml ```yaml
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualizing Results in Zeno\n",
"\n",
"Benchmarking your models is the first step towards making sure your model performs well.\n",
"However, looking at the data behind the benchmark, slicing the data into subsets, and comparing models on individual instances can help you even more in evaluating and quantifying the behavior of your AI system.\n",
"\n",
"All of this can be done in [Zeno](https://zenoml.com)!\n",
"Zeno is super easy to use with the eval harness, let's explore how you can easily upload and visualize your eval results.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install this project if you did not already do that. This is all that needs to be installed for you to be able to visualize your data in Zeno!\n",
"!pip install -e ..\n",
"!pip install -e ..[zeno]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run the Eval Harness\n",
"\n",
"To visualize the results, run the eval harness with the `log_samples` and `output_path` flags. We expect `output_path` to contain multiple folders that represent individual model names. You can thus run your evaluation on any number of tasks and models and upload all of the results as projects on Zeno.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!lm_eval \\\n",
" --model hf \\\n",
" --model_args pretrained=EleutherAI/gpt-neo-2.7B \\\n",
" --tasks hellaswag,wikitext \\\n",
" --batch_size 8 \\\n",
" --device mps \\\n",
" --log_samples \\\n",
" --output_path output/gpt-neo-2.7B \\\n",
" --limit 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set your API Key\n",
"\n",
"This is so you can be authenticated with Zeno.\n",
"If you don't already have a Zeno account, first create an account on [Zeno Hub](https://hub.zenoml.com).\n",
"After logging in to Zeno Hub, generate your API key by clicking on your profile at the bottom left to navigate to your account page.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%env ZENO_API_KEY=YOUR_API_KEY"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize Eval Results\n",
"\n",
"You can now use the `zeno_visualize` script to upload the results to Zeno.\n",
"\n",
"This will use all subfolders in `data_path` as different models and upload all tasks within these model folders to Zeno. If you run the eval harness on multiple tasks, the `project_name` will be used as a prefix and one project will be created per task.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!python ../scripts/zeno_visualize.py --data_path output --project_name \"Zeno Upload Test\""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "zeno_projects",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import argparse
import json
import logging
import os import os
import re import re
import sys import sys
import json
import logging
import argparse
import numpy as np
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import numpy as np
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.tasks import initialize_tasks, include_path
from lm_eval.api.registry import ALL_TASKS from lm_eval.api.registry import ALL_TASKS
from lm_eval.tasks import include_path, initialize_tasks
from lm_eval.utils import make_table
def _handle_non_serializable(o): def _handle_non_serializable(o):
...@@ -25,73 +26,93 @@ def _handle_non_serializable(o): ...@@ -25,73 +26,93 @@ def _handle_non_serializable(o):
def parse_eval_args() -> argparse.Namespace: def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", default="hf", help="Name of model e.g. `hf`") parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
parser.add_argument( parser.add_argument(
"--tasks", "--tasks",
"-t",
default=None, default=None,
metavar="task1,task2",
help="To get full list of tasks, use the command lm-eval --tasks list", help="To get full list of tasks, use the command lm-eval --tasks list",
) )
parser.add_argument( parser.add_argument(
"--model_args", "--model_args",
"-a",
default="", default="",
help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`", help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
) )
parser.add_argument( parser.add_argument(
"--num_fewshot", "--num_fewshot",
"-f",
type=int, type=int,
default=None, default=None,
metavar="N",
help="Number of examples in few-shot context", help="Number of examples in few-shot context",
) )
parser.add_argument("--batch_size", type=str, default=1) parser.add_argument(
"--batch_size",
"-b",
type=str,
default=1,
metavar="auto|auto:N|N",
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
)
parser.add_argument( parser.add_argument(
"--max_batch_size", "--max_batch_size",
type=int, type=int,
default=None, default=None,
help="Maximal batch size to try with --batch_size auto", metavar="N",
help="Maximal batch size to try with --batch_size auto.",
) )
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
default=None, default=None,
help="Device to use (e.g. cuda, cuda:0, cpu)", help="Device to use (e.g. cuda, cuda:0, cpu).",
) )
parser.add_argument( parser.add_argument(
"--output_path", "--output_path",
"-o",
default=None, default=None,
type=str, type=str,
metavar="= [dir/file.jsonl] [DIR]", metavar="DIR|DIR/file.json",
help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.", help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
) )
parser.add_argument( parser.add_argument(
"--limit", "--limit",
"-L",
type=float, type=float,
default=None, default=None,
metavar="N|0<N<1",
help="Limit the number of examples per task. " help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.", "If <1, limit is a percentage of the total number of examples.",
) )
parser.add_argument( parser.add_argument(
"--use_cache", "--use_cache",
"-c",
type=str, type=str,
default=None, default=None,
metavar="DIR",
help="A path to a sqlite db file for caching model responses. `None` if not caching.", help="A path to a sqlite db file for caching model responses. `None` if not caching.",
) )
parser.add_argument("--decontamination_ngrams_path", default=None) # TODO: not used parser.add_argument("--decontamination_ngrams_path", default=None) # TODO: not used
parser.add_argument( parser.add_argument(
"--check_integrity", "--check_integrity",
action="store_true", action="store_true",
help="Whether to run the relevant part of the test suite for the tasks", help="Whether to run the relevant part of the test suite for the tasks.",
) )
parser.add_argument( parser.add_argument(
"--write_out", "--write_out",
"-w",
action="store_true", action="store_true",
default=False, default=False,
help="Prints the prompt for the first few documents", help="Prints the prompt for the first few documents.",
) )
parser.add_argument( parser.add_argument(
"--log_samples", "--log_samples",
"-s",
action="store_true", action="store_true",
default=False, default=False,
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis", help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
) )
parser.add_argument( parser.add_argument(
"--show_config", "--show_config",
...@@ -103,21 +124,24 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -103,21 +124,24 @@ def parse_eval_args() -> argparse.Namespace:
"--include_path", "--include_path",
type=str, type=str,
default=None, default=None,
metavar="DIR",
help="Additional path to include if there are external tasks to include.", help="Additional path to include if there are external tasks to include.",
) )
parser.add_argument( parser.add_argument(
"--gen_kwargs", "--gen_kwargs",
default="", default=None,
help=( help=(
"String arguments for model generation on greedy_until tasks," "String arguments for model generation on greedy_until tasks,"
" e.g. `temperature=0,top_k=0,top_p=0`" " e.g. `temperature=0,top_k=0,top_p=0`."
), ),
) )
parser.add_argument( parser.add_argument(
"--verbosity", "--verbosity",
type=str, "-v",
type=str.upper,
default="INFO", default="INFO",
help="Log error when tasks are not registered.", metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
) )
return parser.parse_args() return parser.parse_args()
...@@ -147,7 +171,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -147,7 +171,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
task_names = ALL_TASKS task_names = ALL_TASKS
elif args.tasks == "list": elif args.tasks == "list":
eval_logger.info( eval_logger.info(
"Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS))) "Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS)))
) )
sys.exit() sys.exit()
else: else:
...@@ -179,7 +203,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -179,7 +203,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
) )
raise ValueError( raise ValueError(
f"Tasks {missing} were not found. Try `lm-eval --tasks list` for list of available tasks." f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
) )
if args.output_path: if args.output_path:
...@@ -224,7 +248,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -224,7 +248,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None: if results is not None:
if args.log_samples: if args.log_samples:
samples = results.pop("samples") samples = results.pop("samples")
dumped = json.dumps(results, indent=2, default=_handle_non_serializable) dumped = json.dumps(
results, indent=2, default=_handle_non_serializable, ensure_ascii=False
)
if args.show_config: if args.show_config:
print(dumped) print(dumped)
...@@ -240,17 +266,20 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -240,17 +266,20 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
) )
filename = path.joinpath(f"{output_name}.jsonl") filename = path.joinpath(f"{output_name}.jsonl")
samples_dumped = json.dumps( samples_dumped = json.dumps(
samples[task_name], indent=2, default=_handle_non_serializable samples[task_name],
indent=2,
default=_handle_non_serializable,
ensure_ascii=False,
) )
filename.open("w").write(samples_dumped) filename.write_text(samples_dumped, encoding="utf-8")
print( print(
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
) )
print(evaluator.make_table(results)) print(make_table(results))
if "groups" in results: if "groups" in results:
print(evaluator.make_table(results, "groups")) print(make_table(results, "groups"))
if __name__ == "__main__": if __name__ == "__main__":
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
from lm_eval.api.instance import Instance
from datasets import Dataset from datasets import Dataset
from lm_eval.api.instance import Instance
class Filter: class Filter:
""" """
...@@ -42,7 +43,6 @@ class FilterEnsemble: ...@@ -42,7 +43,6 @@ class FilterEnsemble:
filters: List[Filter] filters: List[Filter]
def apply(self, instances: List[Instance], docs: List[Dataset]) -> None: def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
resps = [ resps = [
inst.resps for inst in instances inst.resps for inst in instances
] # operate just on the model responses ] # operate just on the model responses
......
import logging
import math import math
import random
from collections.abc import Iterable from collections.abc import Iterable
from collections import defaultdict from collections import defaultdict
import evaluate
import numpy as np import numpy as np
import sacrebleu import sacrebleu
import sklearn.metrics import sklearn.metrics
import random
import evaluate
from lm_eval.api.registry import register_metric, register_aggregation from lm_eval.api.registry import register_aggregation, register_metric
import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
# Register Aggregations First # Register Aggregations First
@register_aggregation("mean") @register_aggregation("mean")
def mean(arr): def mean(arr):
......
import abc import abc
import hashlib
import json
import logging
import os import os
from typing import List, Optional, Tuple, Type, TypeVar
import torch
from typing import Union, List, Tuple, Optional, Type, TypeVar
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
import json
import hashlib
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
......
import os import logging
import evaluate import evaluate
from lm_eval.api.model import LM from lm_eval.api.model import LM
import logging
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
...@@ -91,7 +92,6 @@ DEFAULT_METRIC_REGISTRY = { ...@@ -91,7 +92,6 @@ DEFAULT_METRIC_REGISTRY = {
def register_metric(**args): def register_metric(**args):
# TODO: do we want to enforce a certain interface to registered metrics? # TODO: do we want to enforce a certain interface to registered metrics?
def decorate(fn): def decorate(fn):
assert "metric" in args assert "metric" in args
name = args["metric"] name = args["metric"]
...@@ -100,7 +100,6 @@ def register_metric(**args): ...@@ -100,7 +100,6 @@ def register_metric(**args):
("higher_is_better", HIGHER_IS_BETTER_REGISTRY), ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
("aggregation", METRIC_AGGREGATION_REGISTRY), ("aggregation", METRIC_AGGREGATION_REGISTRY),
]: ]:
if key in args: if key in args:
value = args[key] value = args[key]
assert ( assert (
...@@ -120,7 +119,6 @@ def register_metric(**args): ...@@ -120,7 +119,6 @@ def register_metric(**args):
def get_metric(name, hf_evaluate_metric=False): def get_metric(name, hf_evaluate_metric=False):
if not hf_evaluate_metric: if not hf_evaluate_metric:
if name in METRIC_REGISTRY: if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name] return METRIC_REGISTRY[name]
...@@ -151,7 +149,6 @@ def register_aggregation(name): ...@@ -151,7 +149,6 @@ def register_aggregation(name):
def get_aggregation(name): def get_aggregation(name):
try: try:
return AGGREGATION_REGISTRY[name] return AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
...@@ -161,7 +158,6 @@ def get_aggregation(name): ...@@ -161,7 +158,6 @@ def get_aggregation(name):
def get_metric_aggregation(name): def get_metric_aggregation(name):
try: try:
return METRIC_AGGREGATION_REGISTRY[name] return METRIC_AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
......
...@@ -40,18 +40,18 @@ class ContextSampler: ...@@ -40,18 +40,18 @@ class ContextSampler:
self.doc_to_text(doc) self.doc_to_text(doc)
if ( if (
self.config.doc_to_choice is None self.config.doc_to_choice is None
or type(self.doc_to_text(doc)) is str or isinstance(self.doc_to_text(doc), str)
) )
else self.doc_to_choice(doc)[self.doc_to_text(doc)] else self.doc_to_choice(doc)[self.doc_to_text(doc)]
) )
+ self.target_delimiter + self.target_delimiter
+ ( + (
str(self.doc_to_target(doc)[0]) str(self.doc_to_target(doc)[0])
if type(self.doc_to_target(doc)) is list if isinstance(self.doc_to_target(doc), list)
else self.doc_to_target(doc) else self.doc_to_target(doc)
if ( if (
self.config.doc_to_choice is None self.config.doc_to_choice is None
or type(self.doc_to_target(doc)) is str or isinstance(self.doc_to_target(doc), str)
) )
else str(self.doc_to_choice(doc)[self.doc_to_target(doc)]) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)])
) )
...@@ -77,8 +77,8 @@ class FirstNSampler(ContextSampler): ...@@ -77,8 +77,8 @@ class FirstNSampler(ContextSampler):
Draw the first `n` samples in order from the specified split. Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
""" """
assert n <= len( assert (
self.docs n <= len(self.docs)
), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available." ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
return self.docs[:n] return self.docs[:n]
......
import abc import abc
from dataclasses import dataclass, field, asdict
import os
import re
import ast import ast
import yaml
import logging import logging
import evaluate
import random import random
import itertools import re
import functools from collections.abc import Callable
from tqdm import tqdm from dataclasses import asdict, dataclass
from typing import Any, List, Literal, Tuple, Union
import datasets import datasets
import numpy as np import numpy as np
from typing import Union, List, Any, Tuple, Literal
from collections.abc import Callable
from lm_eval import utils from lm_eval import utils
from lm_eval.api import samplers from lm_eval.api import samplers
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.filter import FilterEnsemble
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api.metrics import ( from lm_eval.api.metrics import (
bits_per_byte,
mean, mean,
weighted_perplexity, weighted_perplexity,
bits_per_byte,
metric_max_over_ground_truths,
) )
from lm_eval.api.registry import ( from lm_eval.api.registry import (
get_metric, AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation, get_aggregation,
get_metric,
get_metric_aggregation, get_metric_aggregation,
is_higher_better, is_higher_better,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY,
) )
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
ALL_OUTPUT_TYPES = [ ALL_OUTPUT_TYPES = [
"loglikelihood", "loglikelihood",
...@@ -97,12 +86,6 @@ class TaskConfig(dict): ...@@ -97,12 +86,6 @@ class TaskConfig(dict):
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks ] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.dataset_path and os.path.exists(os.path.dirname(self.dataset_path)):
import inspect
from importlib import import_module
self.dataset_path = inspect.getfile(import_module(self.dataset_path))
if self.generation_kwargs is not None: if self.generation_kwargs is not None:
if self.output_type != "generate_until": if self.output_type != "generate_until":
eval_logger.warning( eval_logger.warning(
...@@ -349,9 +332,7 @@ class Task(abc.ABC): ...@@ -349,9 +332,7 @@ class Task(abc.ABC):
elif self.has_validation_docs(): elif self.has_validation_docs():
docs = self.validation_docs() docs = self.validation_docs()
else: else:
assert ( assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
False
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
eval_logger.info(f"Building contexts for task on rank {rank}...") eval_logger.info(f"Building contexts for task on rank {rank}...")
...@@ -546,6 +527,10 @@ class ConfigurableTask(Task): ...@@ -546,6 +527,10 @@ class ConfigurableTask(Task):
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg" "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
) )
if isinstance(self.config.metadata, dict):
if "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None: if self.config.output_type is not None:
assert self.config.output_type in ALL_OUTPUT_TYPES assert self.config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self.config.output_type self.OUTPUT_TYPE = self.config.output_type
...@@ -603,9 +588,9 @@ class ConfigurableTask(Task): ...@@ -603,9 +588,9 @@ class ConfigurableTask(Task):
if "aggregation" in metric_config: if "aggregation" in metric_config:
agg_name = metric_config["aggregation"] agg_name = metric_config["aggregation"]
if type(agg_name) == str: if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(agg_name) self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name): elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = metric_config[ self._aggregation_list[metric_name] = metric_config[
"aggregation" "aggregation"
] ]
...@@ -672,9 +657,7 @@ class ConfigurableTask(Task): ...@@ -672,9 +657,7 @@ class ConfigurableTask(Task):
elif self.has_validation_docs(): elif self.has_validation_docs():
self.task_docs = self.validation_docs() self.task_docs = self.validation_docs()
else: else:
assert ( assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
False
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
# Test One Doc # Test One Doc
self.features = list(self.task_docs.features.keys()) self.features = list(self.task_docs.features.keys())
...@@ -686,20 +669,20 @@ class ConfigurableTask(Task): ...@@ -686,20 +669,20 @@ class ConfigurableTask(Task):
if self.config.doc_to_choice is not None: if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc) test_choice = self.doc_to_choice(test_doc)
if type(test_choice) is not list: if not isinstance(test_choice, list):
eval_logger.error("doc_to_choice must return list") eval_logger.error("doc_to_choice must return list")
else: else:
num_choice = len(test_choice) num_choice = len(test_choice)
if type(test_text) is int: if isinstance(test_text, int):
self.multiple_input = num_choice self.multiple_input = num_choice
else: else:
test_choice = None test_choice = None
if type(test_target) is list: if isinstance(test_target, list):
self.multiple_target = len(test_target) self.multiple_target = len(test_target)
else: else:
if (type(test_target) is int) and (test_choice is not None): if (isinstance(test_target, int)) and (test_choice is not None):
test_target = test_choice[test_target] test_target = test_choice[test_target]
else: else:
test_target = str(test_target) test_target = str(test_target)
...@@ -719,11 +702,11 @@ class ConfigurableTask(Task): ...@@ -719,11 +702,11 @@ class ConfigurableTask(Task):
) )
if delimiter_has_whitespace and choice_has_whitespace: if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.warning( eval_logger.debug(
f'Both target_delimiter and target choice: "{choice}" have whitespace' f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
) )
elif (not delimiter_has_whitespace) and (not choice_has_whitespace): elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.warning( eval_logger.debug(
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
) )
...@@ -776,6 +759,8 @@ class ConfigurableTask(Task): ...@@ -776,6 +759,8 @@ class ConfigurableTask(Task):
def fewshot_docs(self): def fewshot_docs(self):
if self.config.fewshot_split is not None: if self.config.fewshot_split is not None:
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.fewshot_split])
return self.dataset[self.config.fewshot_split] return self.dataset[self.config.fewshot_split]
else: else:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
...@@ -808,16 +793,19 @@ class ConfigurableTask(Task): ...@@ -808,16 +793,19 @@ class ConfigurableTask(Task):
) )
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
if type(example) == str: if self.multiple_input:
return labeled_examples + example return labeled_examples
elif type(example) == list: else:
return [labeled_examples + ex for ex in example] if isinstance(example, str):
elif type(example) == int: return labeled_examples + example
if self.config.doc_to_choice is not None: elif isinstance(example, list):
choices = self.doc_to_choice(doc) return [labeled_examples + ex for ex in example]
return labeled_examples + choices[example] elif isinstance(example, int):
else: if self.config.doc_to_choice is not None:
return labeled_examples + str(example) choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
return labeled_examples + str(example)
def apply_filters(self): def apply_filters(self):
if hasattr(self, "_filters"): if hasattr(self, "_filters"):
...@@ -864,9 +852,9 @@ class ConfigurableTask(Task): ...@@ -864,9 +852,9 @@ class ConfigurableTask(Task):
else: else:
doc_to_text = self.config.doc_to_text doc_to_text = self.config.doc_to_text
if type(doc_to_text) == int: if isinstance(doc_to_text, int):
return doc_to_text return doc_to_text
elif type(doc_to_text) == str: elif isinstance(doc_to_text, str):
if doc_to_text in self.features: if doc_to_text in self.features:
# if self.config.doc_to_choice is not None: # if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]] # return self.doc_to_choice(doc)[doc[doc_to_text]]
...@@ -898,9 +886,9 @@ class ConfigurableTask(Task): ...@@ -898,9 +886,9 @@ class ConfigurableTask(Task):
else: else:
doc_to_target = self.config.doc_to_target doc_to_target = self.config.doc_to_target
if type(doc_to_target) == int: if isinstance(doc_to_target, int):
return doc_to_target return doc_to_target
elif type(doc_to_target) == str: elif isinstance(doc_to_target, str):
if doc_to_target in self.features: if doc_to_target in self.features:
# if self.config.doc_to_choice is not None: # if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]] # return self.doc_to_choice(doc)[doc[doc_to_target]]
...@@ -921,7 +909,7 @@ class ConfigurableTask(Task): ...@@ -921,7 +909,7 @@ class ConfigurableTask(Task):
return target_string return target_string
else: else:
return target_string return target_string
elif type(doc_to_target) == list: elif isinstance(doc_to_target, list):
return doc_to_target return doc_to_target
elif callable(doc_to_target): elif callable(doc_to_target):
return doc_to_target(doc) return doc_to_target(doc)
...@@ -944,14 +932,14 @@ class ConfigurableTask(Task): ...@@ -944,14 +932,14 @@ class ConfigurableTask(Task):
else: else:
doc_to_choice = self.config.doc_to_choice doc_to_choice = self.config.doc_to_choice
if type(doc_to_choice) == str: if isinstance(doc_to_choice, str):
if doc_to_choice in self.features: if doc_to_choice in self.features:
return doc[doc_to_choice] return doc[doc_to_choice]
else: else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif type(doc_to_choice) == list: elif isinstance(doc_to_choice, list):
return doc_to_choice return doc_to_choice
elif type(doc_to_choice) == dict: elif isinstance(doc_to_choice, dict):
return list(doc_to_choice.values()) return list(doc_to_choice.values())
elif callable(doc_to_choice): elif callable(doc_to_choice):
return doc_to_choice(doc) return doc_to_choice(doc)
...@@ -973,7 +961,9 @@ class ConfigurableTask(Task): ...@@ -973,7 +961,9 @@ class ConfigurableTask(Task):
if self.multiple_input: if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx # If there are multiple inputs, choices are placed in the ctx
cont = self.doc_to_target(doc) cont = self.doc_to_target(doc)
arguments = [(ctx, f"{target_delimiter}{cont}") for ctx in choices] arguments = [
(ctx + choice, f"{target_delimiter}{cont}") for choice in choices
]
else: else:
# Otherwise they are placed in the continuation # Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
...@@ -1085,14 +1075,14 @@ class ConfigurableTask(Task): ...@@ -1085,14 +1075,14 @@ class ConfigurableTask(Task):
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
gold_index_error = False gold_index_error = False
if type(gold) is list: if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold] gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold: if -100 in gold:
gold_index_error = True gold_index_error = True
else: else:
if type(gold) is int: if isinstance(gold, int):
gold = gold if gold < len(choices) else -100 gold = gold if gold < len(choices) else -100
elif type(gold) is str: elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100 gold = choices.index(gold) if gold in choices else -100
if gold == -100: if gold == -100:
...@@ -1164,27 +1154,36 @@ class ConfigurableTask(Task): ...@@ -1164,27 +1154,36 @@ class ConfigurableTask(Task):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold) # print(gold)
gold = [gold] gold = [gold]
for gold_option in gold: if metric == "exact_match":
try: result = [result for _ in range(len(gold))]
result_score = self._metric_fn_list[metric]( scores = self._metric_fn_list[metric](
references=[gold_option], references=gold,
predictions=[result], predictions=result,
**self._metric_fn_kwargs[metric], **self._metric_fn_kwargs[metric],
) )[metric]
except ( result_score = 1.0 if scores > 0.0 else 0.0
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
scores.append(result_score)
if any(scores):
result_score = 1.0
else: else:
result_score = 0.0 for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
references=[gold_option],
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
scores.append(result_score)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else: else:
try: try:
result_score = self._metric_fn_list[metric]( result_score = self._metric_fn_list[metric](
...@@ -1192,9 +1191,7 @@ class ConfigurableTask(Task): ...@@ -1192,9 +1191,7 @@ class ConfigurableTask(Task):
predictions=[result], predictions=[result],
**self._metric_fn_kwargs[metric], **self._metric_fn_kwargs[metric],
) )
except ( except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
TypeError
): # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result]) result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict): if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict. # TODO: this handles the case where HF evaluate returns a dict.
......
import datetime
import io
import json
import mmap
import os import os
from pathlib import Path
from typing import Any from typing import Any
import zstandard
import json
import jsonlines import jsonlines
import io
import datetime
import mmap
import tqdm import tqdm
from pathlib import Path import zstandard
def json_serial(obj: Any) -> str: def json_serial(obj: Any) -> str:
......
import time import collections
import random
import pickle
import json
import glob import glob
import json
import os import os
import collections import pickle
import random
import time
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader from .archiver import ZStdTextReader
from .janitor import Janitor, word_ngrams
# Was used for testing the evaluator decoupled from the full logic below # Was used for testing the evaluator decoupled from the full logic below
...@@ -109,7 +109,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d ...@@ -109,7 +109,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
print(f"Merging lookups took {elapsed:0.5f} seconds.") print(f"Merging lookups took {elapsed:0.5f} seconds.")
print(f"{ngrams_n_size} grams files found in {ngrams_path}:") print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst")) files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
print(files) print(files)
for file in files: for file in files:
...@@ -135,11 +135,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d ...@@ -135,11 +135,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
matching_unique += 1 matching_unique += 1
for task_name, task_set, doc_ids in merged_lookup[ngram]: for task_name, task_set, doc_ids in merged_lookup[ngram]:
task_doc_set = duplicates[(task_name, task_set)] task_doc_set = duplicates[(task_name, task_set)]
for ( for doc_id in doc_ids: # Record contamination across all relevant task/set combos
doc_id
) in (
doc_ids
): # Record contamination across all relevant task/set combos
task_doc_set.add(doc_id) task_doc_set.add(doc_id)
del merged_lookup[ngram] # No point matching again del merged_lookup[ngram] # No point matching again
else: else:
......
import pickle
import re import re
import string import string
import pickle
import traceback import traceback
from pprint import pprint from typing import Iterator, List, Sequence, Tuple, TypeVar
from typing import Iterator, Sequence, TypeVar, List, Tuple
# This is a cpp module. Compile janitor_util.cpp with: # This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
......