Unverified commit edd7dde3 authored by Hailey Schoelkopf, committed by GitHub

Merge branch 'big-refactor' into haileyschoelkopf-patch-2

parents fb436108 365fcda9
@@ -3,10 +3,10 @@ name: Tasks Modified
 on:
   push:
     branches:
-      - big-refactor
+      - 'big-refactor*'
   pull_request:
     branches:
-      - big-refactor
+      - 'big-refactor*'
   workflow_dispatch:
 # comment/edit out the above to stop/change the triggers
 jobs:
@@ -18,7 +18,7 @@ jobs:
       - name: checkout
         uses: actions/checkout@v3
         with:
-          fetch-depth: 0 # OR "2" -> To retrieve the preceding commit.
+          fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
       # Uses the tj-actions/changed-files@v37 action to check for changes.
       # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
@@ -51,6 +51,7 @@ jobs:
         with:
           python-version: 3.9
           cache: 'pip'
+          cache-dependency-path: setup.py
       - name: Install dependencies
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
         run: |
@@ -62,10 +63,10 @@ jobs:
       - name: Test with pytest
         # if new tasks are added, run tests on them
         if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
-        run: python -m pytest tests/extra/test_new_tasks.py -s -vv -n=auto
+        run: python -m pytest tests/test_tasks.py -s -vv -n=auto
       # if api is modified, run tests on it
       - name: Test more tasks with pytest
         env:
           API: true
         if: steps.changed-tasks.outputs.api_any_modified == 'true'
-        run: python -m pytest tests/extra/test_new_tasks.py -s -vv -n=auto
+        run: python -m pytest tests/test_tasks.py -s -vv -n=auto
@@ -26,7 +26,8 @@ jobs:
         uses: actions/setup-python@v4
         with:
           python-version: 3.9
-          cache: 'pip'
+          cache: pip
+          cache-dependency-path: setup.py
       - name: Install dependencies
         run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
       - name: Pre-Commit
@@ -46,22 +47,32 @@ jobs:
   testcpu:
     name: CPU Tests
     runs-on: ubuntu-latest
-    timeout-minutes: 20
+    strategy:
+      matrix:
+        python-version: [ "3.9", "3.10", "3.11" ]
+    timeout-minutes: 30
     steps:
       - name: Checkout Code
         uses: actions/checkout@v3
-      - name: Set up Python 3.9
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
         with:
-          python-version: 3.9
-          cache: 'pip'
+          python-version: ${{ matrix.python-version }}
+          cache: pip
+          cache-dependency-path: setup.py
       - name: Install dependencies
         run: |
          python -m pip install --upgrade pip
          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
          # Install optional git dependencies
          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
       - name: Test with pytest
         run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra
+      - name: Archive artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: output_results
+          path: |
+            test_logs/*
@@ -43,3 +43,9 @@ repos:
             .*\.json|ignore.txt
           )$
         args: [--check-filenames, --check-hidden, --ignore-words=ignore.txt]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.5.1
+    hooks:
+      - id: mypy
+        additional_dependencies: [".[sentencepiece,multilingual,promptsource,gptq]", "types-PyYAML", "types-requests"]
+        exclude: ^tests/.*$
@@ -20,7 +20,7 @@ This project provides a unified framework to test generative language models on
 Features:
-- Many tasks implemented, 200+ tasks [implemented in the old framework](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md) which require porting to the new setup as described in [the new task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/docs/new_task_guide.md).
+- Many tasks implemented, 200+ tasks [implemented in the old framework](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md) which require porting to the new setup as described in [the new task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md).
 - Support for models loaded via [transformers](https://github.com/huggingface/transformers/) (including quantization via [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)), [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), and [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/), with a flexible tokenization-agnostic interface.
 - Support for commercial APIs including [OpenAI](https://openai.com), [goose.ai](https://goose.ai), and [TextSynth](https://textsynth.com/).
 - Support for evaluation on adapters (e.g. LoRa) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft).
@@ -116,8 +116,10 @@ accelerate launch main.py \
 This will perform *data-parallel evaluation*: that is, placing a **single full copy** of your model onto each available GPU and *splitting batches across GPUs* to evaluate on K GPUs K times faster than on one.
-However, if your model *is too large to be run on a single one of your GPUs*, then we provide an alternative method to run these large models: use of the `parallelize` argument.
+If your model *is too large to be run on a single one of your GPUs*, then you can use `accelerate` with Fully Sharded Data Parallel (FSDP), which splits the weights of the model across your data parallel ranks. To enable this, ensure you select `YES` when asked ```Do you want to use FullyShardedDataParallel?``` when running `accelerate config`. To enable memory-efficient loading, select `YES` when asked `Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start?`. This will ensure only the rank 0 process loads the model and then broadcasts the parameters to the other ranks, instead of having each rank load all parameters, which can lead to large RAM usage spikes around the start of the script that may cause errors.
+We also provide a second method to run these large models: use of the `parallelize` argument.
 ```
 python main.py \
     --model hf \
@@ -132,7 +134,7 @@ To pass even more advanced keyword arguments to `accelerate`, we allow for the f
 - `max_cpu_memory`: the max amount of CPU memory to use when offloading the model weights to RAM.
 - `offload_folder`: a folder where model weights will be offloaded to disk if needed.
-Using this setting helps for massive models like BLOOM which require, or to avoid exceeding your total system RAM (by default, with `accelerate launch` one copy of the model for each GPU is initialized in RAM before moving it to GPU, resulting in large RAM usage spikes around the start of the script that may cause errors such as `Killed`.) However, it naively splits models across GPUs, resulting in only a single GPU performing work at any point in time, and so is much slower than launching with `accelerate launch`, possibly by a factor of the total # of GPUs.
+Note that this method naively splits models across GPUs, resulting in only a single GPU performing work at any point in time, and so is much slower than launching with `accelerate launch`, possibly by a factor of the total # of GPUs.
 **Note that this option requires launching evaluation via `python main.py` rather than `accelerate launch main.py`.**
......
@@ -69,6 +69,8 @@ touch lm_eval/tasks/<dataset_name>/utils.py
 ```
 Now, in `utils.py` we'll write a function to process each split of our dataset:
+TODO: Change the example to one that's in the tasks/
 ```python
 def process_docs(dataset: datasets.Dataset):
     def _helper(doc):
@@ -86,40 +88,53 @@ Now, in our YAML config file we'll use the `!function` constructor, and tell the
 process_docs: !function utils.process_docs
 ```
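The body of the `process_docs` function referenced above is collapsed in this diff. As a point of orientation only, a minimal sketch of what such a function might look like is given below; the `"question"` and `"answer"` column names are illustrative assumptions, not part of this commit.

```python
import datasets


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Illustrative preprocessing: map each raw document to the fields the prompt expects."""

    def _helper(doc):
        # hypothetical column names -- replace with your dataset's real features
        doc["question"] = doc["question"].strip()
        doc["answer"] = doc["answer"].strip()
        return doc

    # datasets.Dataset.map applies _helper to every row and returns a new Dataset
    return dataset.map(_helper)
```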
+## Writing a Prompt Template
-### Writing a prompt with Jinja 2
 The next thing we need to do is decide what format to use when presenting the data to the LM. This is our **prompt**, where we'll define both an input and output format.
-We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format.
+To write a prompt, users will use `doc_to_text`, `doc_to_target`, and `doc_to_choice` (optional when certain conditions are met).
+`doc_to_text` defines the input string a model will be given, while `doc_to_target` and `doc_to_choice` will be used to generate the target text. `doc_to_target` can be either a text string that refers to the target string or an integer that refers to the index of the correct label. When it is set as an index, `doc_to_choice` must also be set with the appropriate list of possible choice strings.
-To write a prompt, users are required to write two or three YAML fields in Jinja as strings:
+### Basic prompts
+If a dataset is straightforward enough, users can enter the feature name directly. This assumes that no preprocessing is required. For example, in [Swag](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/swag/swag.yaml#L10-L11), `doc_to_text` and `doc_to_target` are each given the name of one of the dataset's features.
 ```yaml
-doc_to_text:
+doc_to_text: startphrase
-doc_to_target:
+doc_to_target: label
-doc_to_choice:
 ```
-Suppose our dataset has a `"question"` field, and an `"answer"` field, which are both strings. We want the model to see, if given a `document` object that is a row of our dataset:
+Hard-coding is also possible, as is the case in [SciQ](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/sciq/sciq.yaml#L11).
+```yaml
+doc_to_target: 3
 ```
-Question: {document[question]}
+`doc_to_choice` can be directly given a list of text as options (see [Toxigen](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/toxigen/toxigen.yaml#L11)).
+```yaml
+doc_to_choice: ['No', 'Yes']
+```
+### Writing a prompt with Jinja 2
+We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format.
+Take, for example, `super_glue/boolq`: as input, we'd like to use the features `passage` and `question` and string them together so that, for a sample line `doc`, the model sees something in the format of:
+```
+doc["passage"]
+Question: doc["question"]?
 Answer:
 ```
-We do this by writing
+We do this by [writing](https://github.com/EleutherAI/lm-evaluation-harness/blob/1710b42d52d0f327cb0eb3cb1bfbbeca992836ca/lm_eval/tasks/super_glue/boolq/default.yaml#L9C1-L9C61)
 ```yaml
-doc_to_text: "Question: {{question}}\nAnswer:"
+doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
 ```
-Such that {{question}} will be replaced by `doc["question"]` when rendering the prompt template.
+Such that `{{passage}}` will be replaced by `doc["passage"]` and `{{question}}` with `doc["question"]` when rendering the prompt template.
 Our intended output is for the model to predict a single whitespace, and then the answer to the question. We do this via:
 ```yaml
 doc_to_target: "{{answer}}"
-gold_alias: "{{answer}}"
 ```
-where `doc_to_target` is *the string that will be appended to inputs for each few-shot example*, and `gold_alias` is *what is passed to our metric function as reference or gold answer to score against*. For example, for GSM8k word problems, `doc_to_target` should be the reference text reasoning chain given in the dataset culminating in the answer, and `gold_alias` should be **only the numeric answer** to the word problem that is given at the end of the reasoning chain, and which the evaluated model's answer will be compared against.
-**Important**: We always add one whitespace between the input and output, such that the full input-output string is `doc_to_target(doc) + " " + doc_to_text(doc)`. doc_to_text and doc_to_target should not contain trailing right or left whitespace, respectively.
-Users can also fill out the optional `template_aliases` YAML field, which is added ahead of both the `doc_to_text` and `doc_to_target` fields. This field should not contain any test, but only Jinja variable definitions (`{% ... %}` clauses). This can be used to perform more involved string manipulations and renamings of dataset columns while the main prompt fields remain easy to parse visually.
+**Important**: we now add `target_delimiter` between input and target, which defaults to `" "`, such that the full input-output string is `doc_to_text(doc) + target_delimiter + doc_to_target(doc)`. `doc_to_text` and `doc_to_target` should not contain trailing right or left whitespace, respectively.
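To make that concatenation concrete, here is a small illustrative sketch in plain Python (not harness code) of how an input, the `target_delimiter`, and a target string combine for a boolq-style prompt; the passage and question text are made up for the example.

```python
# Illustrative only -- not the harness implementation.
doc = {
    "passage": "The aurora is visible at night.",
    "question": "can you see the aurora at night",
}

doc_to_text = f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"  # model input
doc_to_target = "yes"                                                     # target string
target_delimiter = " "                                                    # default delimiter

# The scored string is input + delimiter + target:
full_string = doc_to_text + target_delimiter + doc_to_target
print(full_string)
# The aurora is visible at night.
# Question: can you see the aurora at night?
# Answer: yes
```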
 #### Multiple choice format
@@ -135,7 +150,13 @@ doc_to_choice: "{{[distractor1, distractor2, distractor3, correct_answer]}}"
 ```
 Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
+The label index can also be sourced from a feature directly. For example, in `super_glue/boolq`, the label index is defined in the feature `label`. We can set `doc_to_target` as simply `label`. The options or verbalizers can be written in the form of a list `["no", "yes"]` that will correspond to the label index.
+```yaml
+doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
+doc_to_target: label
+doc_to_choice: ["no", "yes"]
+```
 ### Using Python Functions for Prompts
@@ -168,6 +189,10 @@ For example, For Super Glue BoolQ, if we want to use the prompt template `GPT-3
 use_prompt: "promptsource:GPT-3 Style"
 ```
+If you would like to run evaluation on all prompt templates, you can call it this way:
+```
+use_prompt: "promptsource:*"
+```
 ### Setting metrics
@@ -183,11 +208,11 @@ metric_list:
   - metric: <name of the metric here>
     aggregation: <name of the aggregation fn here>
     higher_is_better: <true or false>
-  - metric: ...
+  - metric: !function script.function
     aggregation: ...
     higher_is_better: ...
 ```
-`aggregation` and `higher_is_better` can optionally be left out to default to the manually-set defaults, if using a natively supported metric.
+`aggregation` and `higher_is_better` can optionally be left out to default to the manually-set defaults if using a natively supported metric; otherwise they must be defined explicitly (for example, when using a custom metric implemented as a function).
 For a full list of natively supported metrics and aggregation functions see `docs/advanced_task_guide.md`. All metrics supported in [HuggingFace Evaluate](https://github.com/huggingface/evaluate/tree/main/metrics) can also be used, and will be loaded if a given metric name is not one natively supported in `lm-eval`.
......
-# Advanced Task Configuration
+# Task Configuration
 The `lm-evaluation-harness` is meant to be an extensible and flexible framework within which many different evaluation tasks can be defined. All tasks in the new version of the harness are built around a YAML configuration file format.
@@ -33,7 +33,6 @@ Prompting / in-context formatting options:
 - **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model
 - **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into the list of possible answer choices.
 - **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `greedy_until` tasks.
-- **gold_alias** (`str`, *optional*, defaults to None) — if provided, used to generate the reference answer that is scored against. Used in cases where `doc_to_target` should be the "target string" format appended to each example's input for a fewshot exemplar, so doc_to_target is used for fewshot examples, but the input to the metric function as `gold` is from `gold_alias`.
 - **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
 - **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
......
@@ -4,3 +4,4 @@ nin
 maka
 mor
 te
+ond
from .evaluator import evaluate, simple_evaluate
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from typing import List
 from lm_eval.api.instance import Instance
+from datasets import Dataset
 class Filter:
@@ -13,12 +14,12 @@ class Filter:
     """
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
         """
-    def apply(self, resps):
+    def apply(self, resps, docs):
         """
        Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
        Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
@@ -40,14 +41,14 @@ class FilterEnsemble:
     name: str
     filters: List[Filter]
-    def apply(self, instances: List[Instance]):
+    def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
         resps = [
             inst.resps for inst in instances
         ]  # operate just on the model responses
         for f in self.filters:
             # apply filters in sequence
-            resps = f.apply(resps)
+            resps = f.apply(resps, docs)
         # add the end results after filtering to filtered_requests of their respective source instances.
         # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
......
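For orientation, a minimal sketch of how a concrete filter might implement the updated `apply(resps, docs)` signature introduced above; the class name and the keep-first behavior are illustrative assumptions, not part of this commit.

```python
class TakeFirstFilter(Filter):
    """Illustrative filter: keep only the first response for each document."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def apply(self, resps, docs):
        # `resps` is a list (one entry per document) of response lists;
        # `docs` is now passed alongside so filters can condition on the source document.
        return [resp[:1] for resp in resps]
```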
@@ -19,7 +19,7 @@ class Instance:
     doc_id: str = None
     repeats: str = None
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # unpack metadata field
         self.task_name, self.doc_id, self.repeats = self.metadata
......
@@ -56,6 +56,55 @@ def matthews_corrcoef(items):
     return sklearn.metrics.matthews_corrcoef(golds, preds)
+@register_aggregation("bleu")
+def bleu(items):
+    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
+    for evaluating a generated sentence to a reference sentence. It counts matching
+    n-grams in the candidate translation to n-grams in the reference text, where
+    1-gram or unigram would be each token and a bigram comparison would be each
+    word pair. The comparison is made regardless of word order
+    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
+    Paper: https://www.aclweb.org/anthology/P02-1040/
+    Higher is better
+    """
+    refs = list(zip(*items))[0]
+    preds = list(zip(*items))[1]
+    refs, preds = _sacreformat(refs, preds)
+    return sacrebleu.corpus_bleu(preds, refs).score
+@register_aggregation("chrf")
+def chrf(items):
+    """chrF++ is a tool for automatic evaluation of machine translation output
+    based on character n-gram precision and recall enhanced with word n-grams.
+    Source: https://github.com/m-popovic/chrF
+    Paper: https://www.aclweb.org/anthology/W15-3049.pdf
+    Higher is better  # TODO I think
+    """
+    refs = list(zip(*items))[0]
+    preds = list(zip(*items))[1]
+    refs, preds = _sacreformat(refs, preds)
+    return sacrebleu.corpus_chrf(preds, refs).score
+@register_aggregation("ter")
+def ter(items):
+    """Translation Error Rate is an error metric for machine translation that
+    measures the number of edits required to change a system output into one
+    of the references
+    Source: http://www.cs.umd.edu/~snover/tercom/
+    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
+    Lower is better
+    """
+    refs = list(zip(*items))[0]
+    preds = list(zip(*items))[1]
+    refs, preds = _sacreformat(refs, preds)
+    return sacrebleu.corpus_ter(preds, refs).score
 @register_metric(
     metric="acc",
     higher_is_better=True,
@@ -160,6 +209,36 @@ def f1_fn(items): # This is a passthrough function
     return items
+@register_metric(
+    metric="bleu",
+    higher_is_better=True,
+    output_type="greedy_until",
+    aggregation="bleu",
+)
+def bleu_fn(items):  # This is a passthrough function
+    return items
+@register_metric(
+    metric="chrf",
+    higher_is_better=True,
+    output_type="greedy_until",
+    aggregation="chrf",
+)
+def chrf_fn(items):  # This is a passthrough function
+    return items
+@register_metric(
+    metric="ter",
+    higher_is_better=True,
+    output_type="greedy_until",
+    aggregation="ter",
+)
+def ter_fn(items):  # This is a passthrough function
+    return items
 @register_metric(
     metric="acc_all",
     higher_is_better=True,
@@ -217,55 +296,6 @@ def weighted_mean(items):
     return sum(a) / sum(b)
-@register_metric(metric="bleu", higher_is_better=True, aggregation="mean")
-def bleu(items):
-    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
-    for evaluating a generated sentence to a reference sentence. It counts matching
-    n-grams in the candidate translation to n-grams in the reference text, where
-    1-gram or unigram would be each token and a bigram comparison would be each
-    word pair. The comparison is made regardless of word order
-    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
-    Paper: https://www.aclweb.org/anthology/P02-1040/
-    Higher is better
-    """
-    refs = list(zip(*items))[0]
-    preds = list(zip(*items))[1]
-    refs, preds = _sacreformat(refs, preds)
-    return sacrebleu.corpus_bleu(preds, refs).score
-@register_metric(metric="chrf", higher_is_better=True, aggregation="mean")
-def chrf(items):
-    """chrF++ is a tool for automatic evaluation of machine translation output
-    based on character n-gram precision and recall enhanced with word n-grams.
-    Source: https://github.com/m-popovic/chrF
-    Paper: https://www.aclweb.org/anthology/W15-3049.pdf
-    Higher is better  # TODO I think
-    """
-    refs = list(zip(*items))[0]
-    preds = list(zip(*items))[1]
-    refs, preds = _sacreformat(refs, preds)
-    return sacrebleu.corpus_chrf(preds, refs).score
-@register_metric(metric="ter", higher_is_better=True, aggregation="mean")
-def ter(items):
-    """Translation Error Rate is an error metric for machine translation that
-    measures the number of edits required to change a system output into one
-    of the references
-    Source: http://www.cs.umd.edu/~snover/tercom/
-    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
-    Lower is better
-    """
-    refs = list(zip(*items))[0]
-    preds = list(zip(*items))[1]
-    refs, preds = _sacreformat(refs, preds)
-    return sacrebleu.corpus_ter(preds, refs).score
 def is_non_str_iterable(obj):
     return isinstance(obj, Iterable) and not isinstance(obj, str)
@@ -302,7 +332,7 @@ def _sacreformat(refs, preds):
 class _bootstrap_internal:
-    def __init__(self, f, n):
+    def __init__(self, f, n) -> None:
         self.f = f
         self.n = n
......
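As a point of reference, the registration pattern added above (a passthrough `*_fn` metric paired with a corpus-level aggregation) can be reused for other metrics. A hedged sketch follows; the metric name and scoring logic are hypothetical, and the decorator signatures are assumed to match those used in this diff.

```python
# Hypothetical example mirroring the pattern above; "charlen_ratio" is not a real harness metric.
@register_aggregation("charlen_ratio")
def charlen_ratio(items):
    """Corpus-level ratio of predicted to reference character lengths (illustrative only)."""
    refs = [ref for ref, _ in items]
    preds = [pred for _, pred in items]
    return sum(len(p) for p in preds) / max(1, sum(len(r) for r in refs))


@register_metric(
    metric="charlen_ratio",
    higher_is_better=False,
    output_type="greedy_until",
    aggregation="charlen_ratio",
)
def charlen_ratio_fn(items):  # passthrough; the aggregation does the work
    return items
```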
@@ -13,7 +13,7 @@ from lm_eval.logger import eval_logger
 class LM(abc.ABC):
-    def __init__(self):
+    def __init__(self) -> None:
         """Defines the interface that should be implemented by all LM subclasses.
        LMs are assumed to take text (strings) as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic.)
@@ -133,7 +133,7 @@ class LM(abc.ABC):
         # not support multi-device parallelism nor expect it.
         return self._world_size
-    def set_cache_hook(self, cache_hook):
+    def set_cache_hook(self, cache_hook) -> None:
         self.cache_hook = cache_hook
@@ -144,14 +144,14 @@ def hash_args(attr, args):
 class CacheHook:
-    def __init__(self, cachinglm):
+    def __init__(self, cachinglm) -> None:
         if cachinglm is None:
             self.dbdict = None
             return
         self.dbdict = cachinglm.dbdict
-    def add_partial(self, attr, req, res):
+    def add_partial(self, attr, req, res) -> None:
         if self.dbdict is None:
             return
         hsh = hash_args(attr, req)
@@ -159,7 +159,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm, cache_db):
+    def __init__(self, lm, cache_db) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
        :param lm: LM
......
 class Sampler:
-    def __init__(self, docs, task, fewshot_indices=None, rnd=None):
+    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
         self.rnd = rnd
         assert self.rnd, "must pass rnd to FewShotSampler!"
@@ -19,7 +18,6 @@ class Sampler:
             self.docs = self.docs.select(fewshot_indices)
     def get_context(self, doc, num_fewshot):
-        # draw an extra fewshot sample if using same split as evaluating on
         n_samples = (
             num_fewshot + 1
@@ -74,7 +72,7 @@ class Sampler:
 class BalancedSampler(Sampler):
-    def sample(self, n):
+    def sample(self, n) -> None:
         """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
@@ -84,7 +82,7 @@ class BalancedSampler(Sampler):
 class ManualSampler(Sampler):
-    def sample(self, n):
+    def sample(self, n) -> None:
         """ """
         pass
......
@@ -78,7 +78,7 @@ class TaskConfig(dict):
     # runtime configuration options
     num_fewshot: int = 0
     # scoring options
-    metric_list: str = None
+    metric_list: list = None
     output_type: str = "greedy_until"
     generation_kwargs: dict = None
     repeats: int = 1
@@ -88,7 +88,12 @@ class TaskConfig(dict):
     metadata: str = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
-    def __post_init__(self):
+    def __post_init__(self) -> None:
+        if "." in self.dataset_path:
+            import inspect
+            from importlib import import_module
+            self.dataset_path = inspect.getfile(import_module(self.dataset_path))
         if self.generation_kwargs is not None:
             if self.output_type != "greedy_until":
@@ -171,7 +176,7 @@ class Task(abc.ABC):
         cache_dir=None,
         download_mode=None,
         config=None,
-    ):
+    ) -> None:
         """
        :param data_dir: str
            Stores the path to a local folder containing the `Task`'s data files.
@@ -182,7 +187,6 @@ class Task(abc.ABC):
            HuggingFace `datasets` API with the default cache directory located at:
                `~/.cache/huggingface/datasets`
            NOTE: You can change the cache location globally for a given process
-           by setting the shell environment variable, `HF_DATASETS_CACHE`,
            to another directory:
                `export HF_DATASETS_CACHE="/path/to/another/directory"`
        :param download_mode: datasets.DownloadMode
@@ -213,7 +217,7 @@ class Task(abc.ABC):
             list(self.fewshot_docs()), self, rnd=random.Random(1234)
         )
-    def download(self, data_dir=None, cache_dir=None, download_mode=None):
+    def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
         """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.
@@ -327,7 +331,7 @@ class Task(abc.ABC):
             return rnd.sample(self._training_docs, k)
-    def doc_to_decontamination_query(self, doc):
+    def doc_to_decontamination_query(self, doc) -> None:
         print(
             "Override doc_to_decontamination_query with document specific decontamination query."
         )
@@ -341,7 +345,7 @@ class Task(abc.ABC):
     def doc_to_target(self, doc):
         pass
-    def build_all_requests(self, limit=None, rank=None, world_size=None):
+    def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""
         if self.has_test_docs():
             docs = self.test_docs()
@@ -477,7 +481,6 @@ class Task(abc.ABC):
         return labeled_examples + str(example)
     def apply_filters(self):
         if hasattr(self, "_filters"):
             for f in self._filters:
                 f.apply(self._instances)
@@ -503,7 +506,7 @@ class ConfigurableTask(Task):
     def __init__(
         self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
-    ):  # TODO no super() call here
+    ) -> None:  # TODO no super() call here
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -575,7 +578,6 @@ class ConfigurableTask(Task):
                     "aggregation"
                 ]
             else:
                 INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                 metric_agg = get_default_aggregation(metric_name)
                 eval_logger.warning(
@@ -632,19 +634,19 @@ class ConfigurableTask(Task):
             )
         if self.has_test_docs():
-            docs = self.test_docs()
+            self.task_docs = self.test_docs()
         elif self.has_validation_docs():
-            docs = self.validation_docs()
+            self.task_docs = self.validation_docs()
         else:
             assert (
                 False
             ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
         # Test One Doc
-        self.features = list(docs.features.keys())
+        self.features = list(self.task_docs.features.keys())
         self.multiple_input = 0
         self.multiple_target = 0
-        test_doc = docs[0]
+        test_doc = self.task_docs[0]
         test_text = self.doc_to_text(test_doc)
         test_target = self.doc_to_target(test_doc)
@@ -664,14 +666,14 @@ class ConfigurableTask(Task):
             self.multiple_target = len(test_target)
         else:
             if (type(test_target) is int) and (test_choice is not None):
-                test_target = [self.doc_to_choice(test_target)[test_target]]
+                test_target = test_choice[test_target]
             else:
-                test_target = [test_target]
+                test_target = str(test_target)
         if test_choice is not None:
             check_choices = test_choice
         else:
-            check_choices = test_target
+            check_choices = [test_target]
         for choice in check_choices:
             choice_has_whitespace = True if " " in choice else False
@@ -688,8 +690,7 @@ class ConfigurableTask(Task):
                     f'Both target_delimiter and target choice: "{choice}" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                 )
-    def download(self, dataset_kwargs=None):
+    def download(self, dataset_kwargs=None) -> None:
         self.dataset = datasets.load_dataset(
             path=self.DATASET_PATH,
             name=self.DATASET_NAME,
@@ -748,6 +749,15 @@ class ConfigurableTask(Task):
             )
         return super().fewshot_docs()
+    def apply_filters(self):
+        if hasattr(self, "_filters"):
+            for f in self._filters:
+                f.apply(self._instances, self.task_docs)
+        else:
+            eval_logger.warning("No filter defined, passing through instances")
+            return self._instances
     def should_decontaminate(self):
         return self.config.should_decontaminate
@@ -772,7 +782,6 @@ class ConfigurableTask(Task):
         return doc
     def doc_to_text(self, doc):
         if self.prompt is not None:
             doc_to_text = self.prompt
         else:
@@ -788,7 +797,7 @@ class ConfigurableTask(Task):
                 return doc[doc_to_text]
             else:
                 text_string = utils.apply_template(doc_to_text, doc)
-                if text_string.isdigit():
+                if text_string.isdigit() and self._config.doc_to_choice is not None:
                     return ast.literal_eval(text_string)
                 else:
                     return text_string
@@ -807,7 +816,6 @@ class ConfigurableTask(Task):
             raise TypeError
     def doc_to_target(self, doc: dict) -> Union[int, str, list]:
         if self.prompt is not None:
             doc_to_target = self.prompt
         else:
@@ -823,7 +831,7 @@ class ConfigurableTask(Task):
                 return doc[doc_to_target]
             else:
                 target_string = utils.apply_template(doc_to_target, doc)
-                if target_string.isdigit():
+                if target_string.isdigit() and self._config.doc_to_choice is not None:
                     return ast.literal_eval(target_string)
                 elif (
                     len(target_string) >= 2
@@ -849,7 +857,6 @@ class ConfigurableTask(Task):
             raise TypeError
     def doc_to_choice(self, doc: Any) -> List[str]:
         if self.prompt is not None:
             doc_to_choice = self.prompt
         elif self.config.doc_to_choice is None:
@@ -893,13 +900,11 @@ class ConfigurableTask(Task):
     def construct_requests(
         self, doc: dict, ctx: str, **kwargs
     ) -> Union[List[Instance], Instance]:
         if self.OUTPUT_TYPE == "loglikelihood":
             arguments = (ctx, self.doc_to_target(doc))
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             arguments = (self.doc_to_target(doc),)
         elif self.OUTPUT_TYPE == "multiple_choice":
             choices = self.doc_to_choice(doc)
             target_delimiter = self.config.target_delimiter
             if self.multiple_input:
@@ -985,7 +990,6 @@ class ConfigurableTask(Task):
                 ),
             }
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls, is_greedy = zip(*results)
             # retrieve choices in List[str] form, to compute choice lengths, etc.
@@ -1010,18 +1014,36 @@ class ConfigurableTask(Task):
                 gold = self.doc_to_text(doc)
             else:
                 gold = self.doc_to_target(doc)
-            if type(gold) is str:
-                gold = choices.index(gold)
+            gold_index_error = False
+            if type(gold) is list:
+                gold = [i if i < len(choices) else -100 for i in gold]
+                if -100 in gold:
+                    gold_index_error = True
+            else:
+                if type(gold) is int:
+                    gold = gold if gold < len(choices) else -100
+                elif type(gold) is str:
+                    gold = choices.index(gold) if gold in choices else -100
+                if gold == -100:
+                    gold_index_error = True
+            if gold_index_error:
+                eval_logger.warning(
+                    f"Label index was not in within range of available choices,"
+                    f"Sample:\n\n{doc}\n\n"
+                )
             if self.multiple_target:
                 acc = 1.0 if pred in gold else 0.0
                 acc_norm = 1.0 if pred_norm in gold else 0.0
-                exact_match = int(any([is_greedy[i] for i in gold]))
+                exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
             else:
                 acc = 1.0 if pred == gold else 0.0
                 acc_norm = 1.0 if pred_norm == gold else 0.0
                 # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
-                exact_match = int(is_greedy[gold])
+                exact_match = int(is_greedy[gold]) if gold != -100 else 0
             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
@@ -1039,7 +1061,6 @@ class ConfigurableTask(Task):
                 result_dict["acc_mutual_info"] = acc_mutual_info
         elif self.OUTPUT_TYPE == "greedy_until":
             gold = self.doc_to_target(doc)
             if self.config.doc_to_choice is not None:
                 # If you set doc_to_choice,
@@ -1049,37 +1070,45 @@ class ConfigurableTask(Task):
             else:
                 gold = str(gold)
-            for key, result in zip(self._metric_fn_list.keys(), results):
+            result = results[0]
+            for metric in self._metric_fn_list.keys():
                 if self.multiple_target:
                     # in the case where we have multiple targets,
                     # return true if any are true
                     # TODO: this may break for multipLe_target, non zero-or-1 metrics
                     scores = []
                     for gold_option in gold:
-                        res = self._metric_fn_list[key](
-                            references=[gold_option],
-                            predictions=[result],
-                            **self._metric_fn_kwargs[key],
-                        )
-                        if isinstance(res, dict):
+                        try:
+                            result_score = self._metric_fn_list[metric](
+                                references=[gold_option],
+                                predictions=[result],
+                                **self._metric_fn_kwargs[metric],
+                            )
+                        except TypeError:  # TODO: this is hacky and I don't want to do it
+                            result_score = self._metric_fn_list[metric](
+                                [gold_option, result]
+                            )
+                        if isinstance(result_score, dict):
                             # TODO: this handles the case where HF evaluate returns a dict.
-                            res = res[key]
-                        scores.append(res)
+                            result_score = result_score[metric]
+                        scores.append(result_score)
                     if any(scores):
                         result_score = 1.0
                     else:
                         result_score = 0.0
                 else:
-                    result_score = self._metric_fn_list[key](
-                        references=[gold],
-                        predictions=[result],
-                        **self._metric_fn_kwargs[key],
-                    )
-                    if isinstance(result_score, dict):
-                        result_dict.update(result_score)
-                    else:
-                        result_dict[key] = result_score
+                    try:
+                        result_score = self._metric_fn_list[metric](
+                            references=[gold],
+                            predictions=[result],
+                            **self._metric_fn_kwargs[metric],
+                        )
+                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
+                        result_score = self._metric_fn_list[metric]([gold, result])
+                    if isinstance(result_score, dict):
+                        # TODO: this handles the case where HF evaluate returns a dict.
+                        result_score = result_score[metric]
+                    result_dict[metric] = result_score
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
@@ -1169,7 +1198,7 @@ class PerplexityTask(Task):
     def doc_to_decontamination_query(self, doc):
         return doc
-    def doc_to_text(self, doc):
+    def doc_to_text(self, doc) -> str:
         return ""
     def doc_to_target(self, doc):
......
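For readers unfamiliar with the `TaskConfig.__post_init__` change earlier in this file's diff: when `dataset_path` contains a dot, it is treated as an importable module and replaced by the path of the file that defines it. A small standalone illustration of the same stdlib mechanism follows; the dotted module name here is arbitrary and not a harness task.

```python
import inspect
from importlib import import_module

# Mirrors the `__post_init__` logic added in this diff.
dataset_path = "email.message"  # arbitrary dotted stdlib module, for illustration only
if "." in dataset_path:
    dataset_path = inspect.getfile(import_module(dataset_path))
print(dataset_path)  # e.g. /usr/lib/python3.x/email/message.py
```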
@@ -11,8 +11,7 @@ from lm_eval.api.registry import (
 )
-def include_benchmarks(task_dir):
+def include_benchmarks(task_dir: str) -> None:
     for root, subdirs, file_list in os.walk(task_dir):
         if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
             for f in file_list:
......
 import os
+from typing import Any
 import zstandard
 import json
 import jsonlines
@@ -9,7 +10,7 @@ import tqdm
 from pathlib import Path
-def json_serial(obj):
+def json_serial(obj: Any) -> str:
     """JSON serializer for objects not serializable by default json code"""
     if isinstance(obj, (datetime.datetime,)):
@@ -19,7 +20,7 @@ def json_serial(obj):
 # Modified version of lm_dataformat Archive for single file.
 class Archive:
-    def __init__(self, file_path, compression_level=3):
+    def __init__(self, file_path: str, compression_level: int = 3) -> None:
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
@@ -28,7 +29,7 @@ class Archive:
         self.cctx = zstandard.ZstdCompressor(level=compression_level)
         self.compressor = self.cctx.stream_writer(self.fh)
-    def add_data(self, data, meta={}):
+    def add_data(self, data, meta={}) -> None:
         self.compressor.write(
             json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
                 "UTF-8"
@@ -36,7 +37,7 @@ class Archive:
             + b"\n"
         )
-    def commit(self):
+    def commit(self) -> None:
         self.compressor.flush(zstandard.FLUSH_FRAME)
         self.fh.flush()
         self.fh.close()
@@ -44,10 +45,16 @@ class Archive:
 # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
 class Reader:
-    def __init__(self):
+    def __init__(self) -> None:
         pass
-    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
+    def read(
+        self,
+        file,
+        get_meta: bool = False,
+        autojoin_paragraphs: bool = True,
+        para_joiner: str = "\n\n",
+    ):
         with open(file, "rb") as fh:
             self.fh = fh
             cctx = zstandard.ZstdDecompressor()
@@ -72,7 +79,7 @@ class Reader:
 class TextArchive:
-    def __init__(self, file_path, mode="rb+"):
+    def __init__(self, file_path, mode: str = "rb+") -> None:
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
@@ -83,21 +90,21 @@ class TextArchive:
         self.fh = open(self.file_path, mode)
-    def add_data(self, data):
+    def add_data(self, data) -> None:
         self.fh.write(data.encode("UTF-8") + b"\n")
-    def commit(self):
+    def commit(self) -> None:
         self.fh.flush()
         self.fh.close()
 class TextReader:
-    def __init__(self, file_path):
+    def __init__(self, file_path) -> None:
         self.file_path = file_path
     # Optimized mmap read with infrequent tqdm updates to maintain speed
     # Tested up to 250MB/s.
-    def read_tqdm(self, update_frequency=10000):
+    def read_tqdm(self, update_frequency: int = 10000):
         current_file_position = 0
         line_counter = 0
         with open(self.file_path, "r") as fh, tqdm.tqdm(
@@ -149,7 +156,7 @@ class TextReader:
 # Optimized for speed. Decompresses the archive in shell before
 # using the mmap'd TextReader.
 class ZStdTextReader:
-    def __init__(self, file):
+    def __init__(self, file) -> None:
         self.file = file
     def read_tqdm(self):
......
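For context on how the `Archive` and `Reader` classes above fit together, a small usage sketch follows. The file name is arbitrary, and since the body of `Reader.read` is collapsed in this diff, the assumption that it yields one text record at a time is based on its lm_dataformat lineage rather than on code shown here.

```python
# Write a couple of documents to a zstd-compressed JSONL archive, then read them back.
archive = Archive("data/example.jsonl.zst", compression_level=3)
archive.add_data("first document", meta={"source": "demo"})
archive.add_data("second document")
archive.commit()  # flush the zstd frame and close the file

reader = Reader()
for text in reader.read("data/example.jsonl.zst"):
    print(text)  # assumed: yields the "text" field of each stored record
```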
...@@ -11,7 +11,7 @@ from .archiver import ZStdTextReader ...@@ -11,7 +11,7 @@ from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below # Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size): def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
simulated_overlap = 0.1 simulated_overlap = 0.1
contaminated = int(len(docs) * simulated_overlap) contaminated = int(len(docs) * simulated_overlap)
return random.sample(range(len(docs)), contaminated) return random.sample(range(len(docs)), contaminated)
...@@ -25,6 +25,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size): ...@@ -25,6 +25,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst" # scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function. # files. These should exist in the "ngrams_path" provided to this function.
# Algorithm: # Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)} # 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]} # 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
...@@ -33,7 +34,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size): ...@@ -33,7 +34,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# 4. Strip the task_set from the dictionary keys and return # 4. Strip the task_set from the dictionary keys and return
# #
# We cache the task+set lookups as well as the overlaps. # We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set, ngrams_path, limit): def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size) # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path = os.path.join(ngrams_path, "info.json") info_dict_path = os.path.join(ngrams_path, "info.json")
...@@ -46,7 +47,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit): ...@@ -46,7 +47,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
print("Building Lookups...") print("Building Lookups...")
start = time.perf_counter() start = time.perf_counter()
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit): def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps" return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
lookups = {} lookups = {}
......
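A condensed sketch of steps 1-2 of the algorithm outlined in the comments above, under the assumption that `docs_by_task_set` maps `(task_name, task_set)` pairs to lists of document strings; the helper names are illustrative, and normalization via `Janitor.normalize_string` is omitted for brevity.

import collections

from lm_eval.decontamination.janitor import word_ngrams  # assumed import path; the janitor module appears later in this diff

def build_lookups(docs_by_task_set: dict, ngrams_n_size: int) -> dict:
    # step 1: one {ngram: [document_ids]} lookup per (task_name, task_set)
    lookups = {}
    for (task_name, task_set), docs in docs_by_task_set.items():
        lookup = collections.defaultdict(list)
        for doc_id, doc in enumerate(docs):
            for ngram in word_ngrams(doc, ngrams_n_size):
                lookup[ngram].append(doc_id)
        lookups[(task_name, task_set)] = lookup
    return lookups

def merge_lookups(lookups: dict) -> dict:
    # step 2: overall lookup of {ngram: [(task_name, task_set, doc_ids), ...]}
    merged = collections.defaultdict(list)
    for (task_name, task_set), lookup in lookups.items():
        for ngram, doc_ids in lookup.items():
            merged[ngram].append((task_name, task_set, doc_ids))
    return merged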
import re import re
import string import string
import timeit
import pickle import pickle
import traceback import traceback
from pprint import pprint from pprint import pprint
from typing import Iterator, Sequence, TypeVar
# This is a cpp module. Compile janitor_util.cpp with: # This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
...@@ -16,10 +16,12 @@ except Exception: ...@@ -16,10 +16,12 @@ except Exception:
traceback.print_exc() traceback.print_exc()
JANITOR_CPP = False JANITOR_CPP = False
T = TypeVar("T")
# Implementation from nltk source # Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html # https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n): def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[tuple[T, ...]]:
history = [] history = []
while n > 1: while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
...@@ -36,7 +38,7 @@ def form_ngrams(sequence, n): ...@@ -36,7 +38,7 @@ def form_ngrams(sequence, n):
del history[0] del history[0]
def word_ngrams(s, n): def word_ngrams(s: str, n: int) -> Iterator[str]:
"""Splits a string into ngram words""" """Splits a string into ngram words"""
tokens = s.split() # not a generator :( tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n) ngram_seqs = form_ngrams(iter(tokens), n)
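For orientation, this is what the two generators above yield on a short input (output shown in comments; assumes the nltk-style sliding window and the space-joined strings the `Iterator[str]` annotation implies):

list(form_ngrams(iter("the quick brown fox".split()), 2))
# -> [('the', 'quick'), ('quick', 'brown'), ('brown', 'fox')]

list(word_ngrams("the quick brown fox", 3))
# -> ['the quick brown', 'quick brown fox']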
...@@ -68,14 +70,14 @@ def word_ngrams(s, n): ...@@ -68,14 +70,14 @@ def word_ngrams(s, n):
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python # https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s): def split_indices(s: str) -> Iterator[tuple[str, tuple[int, int]]]:
"""Splits a string on whitespaces and records the indices of each in the original string. """Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...) @:return generator((word, (start_idx, end_idx)), ...)
""" """
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s)) return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
def word_ngrams_indices(s, n): def word_ngrams_indices(s: str, n: int) -> Iterator[tuple[str, tuple[int, int]]]:
"""Splits a string into pairs of (ngram words, their start/end indices)""" """Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s) tokens_with_indices = split_indices(s)
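Likewise for the index-aware variants (assuming each n-gram is paired with the start offset of its first word and the inclusive end offset of its last):

list(split_indices("a bc d"))
# -> [('a', (0, 0)), ('bc', (2, 3)), ('d', (5, 5))]

list(word_ngrams_indices("a bc d", 2))
# -> [('a bc', (0, 3)), ('bc d', (2, 5))]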
...@@ -104,16 +106,15 @@ def word_ngrams_indices(s, n): ...@@ -104,16 +106,15 @@ def word_ngrams_indices(s, n):
class Janitor: class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars? # FIXME delete_chars: Should anything else go here? Special chars?
def __init__( def __init__(
self, self,
ngram_n=13, ngram_n: int = 13,
window_to_remove=200, window_to_remove: int = 200,
too_dirty_cutoff=10, too_dirty_cutoff: int = 10,
minimum_slice_length=200, minimum_slice_length: int = 200,
delete_chars=string.punctuation, delete_chars: str = string.punctuation,
): ) -> None:
self.ngram_n = ngram_n self.ngram_n = ngram_n
self.window_to_remove = window_to_remove self.window_to_remove = window_to_remove
self.too_dirty_cutoff = too_dirty_cutoff self.too_dirty_cutoff = too_dirty_cutoff
...@@ -135,11 +136,11 @@ class Janitor: ...@@ -135,11 +136,11 @@ class Janitor:
# I/O for saving contamination ngrams # I/O for saving contamination ngrams
############## ##############
def save_contamination_ngrams(self, filename): def save_contamination_ngrams(self, filename: str) -> None:
with open(filename, "wb") as fp: with open(filename, "wb") as fp:
pickle.dump(self.dirt_ngrams, fp) pickle.dump(self.dirt_ngrams, fp)
def load_contamination_ngrams(self, filename): def load_contamination_ngrams(self, filename: str) -> None:
with open(filename, "rb") as fp: with open(filename, "rb") as fp:
self.dirt_ngrams = pickle.load(fp) self.dirt_ngrams = pickle.load(fp)
...@@ -147,7 +148,7 @@ class Janitor: ...@@ -147,7 +148,7 @@ class Janitor:
# Call these :) # Call these :)
############## ##############
def register_contaminant(self, dirt_string): def register_contaminant(self, dirt_string: str) -> None:
"""Register a string as contamination to be removed, e.g. a test set """Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning""" This breaks the dirt_string into ngrams to store for future cleaning"""
if JANITOR_CPP: if JANITOR_CPP:
...@@ -156,7 +157,7 @@ class Janitor: ...@@ -156,7 +157,7 @@ class Janitor:
print("WARNING: Janitor running in python mode") print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string) return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string): def clean(self, dirty_string: str) -> list[str]:
"""Clean a string (e.g. a training set) by removing all ngrams previously """Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty""" the string was too dirty"""
...@@ -166,7 +167,9 @@ class Janitor: ...@@ -166,7 +167,9 @@ class Janitor:
print("WARNING: Janitor running in python mode") print("WARNING: Janitor running in python mode")
return self.clean_python(dirty_string) return self.clean_python(dirty_string)
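A hedged end-to-end sketch of the public API described in the docstrings above; the small `ngram_n` and window sizes are chosen purely so a short example string registers, not the 13-gram production defaults.

janitor = Janitor(ngram_n=3, window_to_remove=5, too_dirty_cutoff=10, minimum_slice_length=5)
janitor.register_contaminant("the quick brown fox jumps over the lazy dog")
chunks = janitor.clean("training text ... the quick brown fox jumps ... more training text")
# `chunks` is a list of clean substrings with the contaminated windows removed,
# or [] if more than `too_dirty_cutoff` contaminated ngrams were found.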
def _split_chunks(self, dirty_string, dirty_parts): def _split_chunks(
self, dirty_string: str, dirty_parts: Sequence[tuple]
) -> list[str]:
clean_chunks = [] clean_chunks = []
splice_idx = 0 splice_idx = 0
end = -1 end = -1
...@@ -189,12 +192,12 @@ class Janitor: ...@@ -189,12 +192,12 @@ class Janitor:
# Fast C++ # Fast C++
############## ##############
def register_contaminant_cpp(self, dirt_string): def register_contaminant_cpp(self, dirt_string) -> None:
self.dirt_ngrams.update( self.dirt_ngrams.update(
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n) janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
) )
def clean_cpp(self, dirty_string): def clean_cpp(self, dirty_string: str) -> list[str]:
contamination_indices = janitor_util.clean_ngram_with_indices( contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n dirty_string, self.delete_chars, self.ngram_n
) )
...@@ -204,15 +207,15 @@ class Janitor: ...@@ -204,15 +207,15 @@ class Janitor:
# Slow python # Slow python
############## ##############
def normalize_string(self, s): def normalize_string(self, s: str) -> str:
return s.translate(self.translation_table) return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string): def register_contaminant_python(self, dirt_string: str) -> None:
self.dirt_ngrams.update( self.dirt_ngrams.update(
word_ngrams(self.normalize_string(dirt_string), self.ngram_n) word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
) )
def clean_python(self, dirty_string): def clean_python(self, dirty_string: str) -> list[str]:
contamination_indices = ( contamination_indices = (
(None, *idx_pair) (None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n) for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
......
...@@ -42,11 +42,11 @@ def simple_evaluate( ...@@ -42,11 +42,11 @@ def simple_evaluate(
device=None, device=None,
use_cache=None, use_cache=None,
limit=None, limit=None,
bootstrap_iters=100000, bootstrap_iters: int = 100000,
check_integrity=False, check_integrity: bool = False,
decontamination_ngrams_path=None, decontamination_ngrams_path=None,
write_out=False, write_out: bool = False,
log_samples=True, log_samples: bool = True,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -117,7 +117,6 @@ def simple_evaluate( ...@@ -117,7 +117,6 @@ def simple_evaluate(
task_dict = lm_eval.tasks.get_task_dict(tasks) task_dict = lm_eval.tasks.get_task_dict(tasks)
for task_name in task_dict.keys(): for task_name in task_dict.keys():
task_obj = task_dict[task_name] task_obj = task_dict[task_name]
if type(task_obj) == tuple: if type(task_obj) == tuple:
group, task_obj = task_obj group, task_obj = task_obj
...@@ -175,10 +174,10 @@ def evaluate( ...@@ -175,10 +174,10 @@ def evaluate(
lm, lm,
task_dict, task_dict,
limit=None, limit=None,
bootstrap_iters=100000, bootstrap_iters: int = 100000,
decontamination_ngrams_path=None, decontamination_ngrams_path=None,
write_out=False, write_out: bool = False,
log_samples=True, log_samples: bool = True,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -219,15 +218,14 @@ def evaluate( ...@@ -219,15 +218,14 @@ def evaluate(
padding_requests = collections.defaultdict(int) padding_requests = collections.defaultdict(int)
# Stores group related keys and values for group-aggregation # Stores group related keys and values for group-aggregation
aggregate = collections.defaultdict(dict)
task_groups = collections.defaultdict(dict) task_groups = collections.defaultdict(dict)
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
if type(task) == tuple: if type(task) == tuple:
group, task = task group, task = task
task_groups[task_name] = group task_groups[task_name] = group
aggregate[task_name] = {}
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config()) configs[task_name] = dict(task.dump_config())
...@@ -252,7 +250,8 @@ def evaluate( ...@@ -252,7 +250,8 @@ def evaluate(
# print the prompt for the first few documents # print the prompt for the first few documents
if inst.doc_id < 1: if inst.doc_id < 1:
eval_logger.info( eval_logger.info(
f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\n{inst.args[0]}\n(end of prompt on previous line)" f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
) )
eval_logger.info(f"Request: {str(inst)}") eval_logger.info(f"Request: {str(inst)}")
...@@ -349,7 +348,6 @@ def evaluate( ...@@ -349,7 +348,6 @@ def evaluate(
# if multigpu, then gather data across all ranks # if multigpu, then gather data across all ranks
# first gather logged samples across all ranks # first gather logged samples across all ranks
for task_name, task_samples in list(samples.items()): for task_name, task_samples in list(samples.items()):
full_samples = [None] * lm.world_size full_samples = [None] * lm.world_size
torch.distributed.all_gather_object(full_samples, task_samples) torch.distributed.all_gather_object(full_samples, task_samples)
...@@ -358,33 +356,39 @@ def evaluate( ...@@ -358,33 +356,39 @@ def evaluate(
# then collect metrics across all ranks # then collect metrics across all ranks
vals_torch = collections.defaultdict(list) vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items(): for (task_name, key, metric), items in vals.items():
numitem = 0 numitem = 0
if type(items[0]) == tuple: if type(items[0]) == tuple:
numitem = len(items[0]) numitem = len(items[0])
# distributed gather requires all ranks to have same dimensions if isinstance(items[0], (str, list)):
# so we pad out with float32 min value # handle the string case
pad_value = torch.finfo(torch.float32).min gathered_items = [None] * lm.accelerator.num_processes
metrics_tensor = torch.tensor(items, device=lm.device) torch.distributed.all_gather_object(gathered_items, items)
original_dtype = metrics_tensor.dtype # store original dtype gathered_item = list(itertools.chain.from_iterable(gathered_items))
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
else: else:
gathered_filtered = gathered_item[gathered_item != pad_value] # distributed gather requires all ranks to have same dimensions
# so we pad out with float32 min value
pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device=lm.device)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
gathered_item = ( if numitem > 0:
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist() gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
) else:
# reconvert if we were passed a tuple of values gathered_filtered = gathered_item[gathered_item != pad_value]
if numitem > 0:
gathered_item = [tuple(g) for g in gathered_item] gathered_item = (
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
)
# reconvert if we were passed a tuple of values
if numitem > 0:
gathered_item = [tuple(g) for g in gathered_item]
if lm.rank == 0: if lm.rank == 0:
vals_torch[(task_name, key, metric)] = gathered_item vals_torch[(task_name, key, metric)] = gathered_item
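A standalone sketch of the pad-then-gather branch above for numeric metrics, written against a generic accelerate `Accelerator` rather than the evaluator's `lm` object; the function name is illustrative, not the evaluator's exact code.

import torch
from accelerate import Accelerator

def gather_metric_values(accelerator: Accelerator, items: list) -> list:
    # every rank pads its local values to the longest length so gather() can concatenate them
    pad_value = torch.finfo(torch.float32).min
    local = torch.tensor(items, device=accelerator.device, dtype=torch.float32)
    padded = accelerator.pad_across_processes(local, pad_index=pad_value)
    gathered = accelerator.gather(padded)
    # drop the sentinel entries introduced by the padding
    return gathered[gathered != pad_value].cpu().tolist()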
...@@ -407,16 +411,16 @@ def evaluate( ...@@ -407,16 +411,16 @@ def evaluate(
# | word_perplexity # | word_perplexity
# | byte_perplexity # | byte_perplexity
# | bits_per_byte # | bits_per_byte
if bool(task_groups): if task_name in task_groups:
group_name = task_groups[task_name] group_name = task_groups[task_name]
if metric not in aggregate[group_name]: if metric in list(aggregate[group_name].keys()):
aggregate[group_name][metric] = [task_score]
else:
aggregate[group_name][metric].append(task_score) aggregate[group_name][metric].append(task_score)
else:
aggregate[group_name][metric] = [task_score]
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this # so we run them less iterations. still looking for a cleaner way to do this
if bootstrap_iters > 0: if False: # bootstrap_iters > 0:
stderr = lm_eval.api.metrics.stderr_for_metric( stderr = lm_eval.api.metrics.stderr_for_metric(
metric=task.aggregation()[metric], metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000) bootstrap_iters=min(bootstrap_iters, 1000)
......
...@@ -17,14 +17,16 @@ FILTER_REGISTRY = { ...@@ -17,14 +17,16 @@ FILTER_REGISTRY = {
def get_filter(filter_name): def get_filter(filter_name):
return FILTER_REGISTRY[filter_name] if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name]
else:
return filter_name
def build_filter_ensemble(filter_name, components): def build_filter_ensemble(filter_name, components):
""" """
Create a filtering pipeline. Create a filtering pipeline.
""" """
filters = [] filters = []
for (function, kwargs) in components: for (function, kwargs) in components:
if kwargs is None: if kwargs is None:
......
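A small illustration of the fallback added to `get_filter`: registered names resolve through `FILTER_REGISTRY`, while anything else is returned unchanged so callers can pass a filter object directly. `"take_first"` is assumed to be a registry key here; the custom callable is hypothetical.

def my_custom_filter(resps, docs):  # hypothetical user-supplied filter, not in FILTER_REGISTRY
    return resps

assert get_filter(my_custom_filter) is my_custom_filter  # unregistered objects pass straight through
take_first_cls = get_filter("take_first")                # assumed registry key -> registered filter class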