Commit 2184b8de authored by lintangsutawika's avatar lintangsutawika
Browse files

Merge branch 'cont-metrics' of https://github.com/EleutherAI/lm-evaluation-harness into alt_worlds

parents b1ba4e71 1522009c
...@@ -43,7 +43,7 @@ jobs: ...@@ -43,7 +43,7 @@ jobs:
# # mypy turned off for now # # mypy turned off for now
# - name: Lint with mypy # - name: Lint with mypy
# run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable # run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
Job 2 # Job 2
testcpu: testcpu:
name: CPU Tests name: CPU Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
......
* @haileyschoelkopf @lintangsutawika * @haileyschoelkopf @lintangsutawika @StellaAthena
# Language Model Evaluation Harness # Language Model Evaluation Harness
## Notice to Users
(as of 6/15/23)
We have a revamp of the Evaluation Harness library internals staged on the [big-refactor](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) branch! It is far along in progress, but before we start to move the `master` branch of the repository over to this new design with a new version release, we'd like to ensure that it's been tested by outside users and there are no glaring bugs.
We’d like your help to test it out! you can help by:
1. Trying out your current workloads on the big-refactor branch, and seeing if anything breaks or is counterintuitive,
2. Porting tasks supported in the previous version of the harness to the new YAML configuration format. Please check out our [task implementation guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md) for more information.
If you choose to port a task not yet completed according to [our checklist](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/tasks/README.md), then you can contribute it by opening a PR containing [Refactor] in the name with:
- A command of the form `python -m lm_eval --model hf --model_args ..... --tasks <task name> ...` which will run the task in the `master` branch, and what the score is
- A command of the form `python -m lm_eval --model hf --model_args ..... --tasks <task name> ...` to run the task in your PR branch to `big-refactor`, and what the resulting score is, to show that we achieve equality between the two implementations.
Lastly, we'll no longer be accepting new feature requests beyond those that are already open to the master branch as we carry out this switch to the new version over the next week, though we will be accepting bugfixes to `master` branch and PRs to `big-refactor`. Feel free to reach out in the #lm-thunderdome channel of the EAI discord for more information.
## Overview ## Overview
This project provides a unified framework to test generative language models on a large number of different evaluation tasks. This project provides a unified framework to test generative language models on a large number of different evaluation tasks.
Features: Features:
- Many tasks implemented, 200+ tasks [implemented in the old framework](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md) which require porting to the new setup as described in [the new task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md). - Over 60 standard academic benchmarks for LLMs, with hundreds of subtasks and variants implemented.
- Support for models loaded via [transformers](https://github.com/huggingface/transformers/) (including quantization via [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)), [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), and [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/), with a flexible tokenization-agnostic interface. - Support for models loaded via [transformers](https://github.com/huggingface/transformers/) (including quantization via [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)), [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), and [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/), with a flexible tokenization-agnostic interface.
- Support for commercial APIs including [OpenAI](https://openai.com), [goose.ai](https://goose.ai), and [TextSynth](https://textsynth.com/). - Support for commercial APIs including [OpenAI](https://openai.com), [goose.ai](https://goose.ai), and [TextSynth](https://textsynth.com/).
- Support for evaluation on adapters (e.g. LoRa) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft). - Support for evaluation on adapters (e.g. LoRA) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft).
- Evaluating with publicly available prompts ensures reproducibility and comparability between papers. - Support for local models and benchmarks.
- Evaluation with publicly available prompts ensures reproducibility and comparability between papers.
The Language Model Evaluation Harness is the backend for 🤗 Hugging Face's popular [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard) and is used internally by dozens of companies including NVIDIA, Cohere, Booz Allen Hamilton, and Mosaic ML.
## Install ## Install
To install the `lm-eval` refactor branch from the github repository, run: To install the `lm-eval` package from the github repository, run:
```bash ```bash
git clone https://github.com/EleutherAI/lm-evaluation-harness git clone https://github.com/EleutherAI/lm-evaluation-harness
...@@ -54,7 +44,6 @@ To install the package with all extras, run ...@@ -54,7 +44,6 @@ To install the package with all extras, run
pip install -e ".[all]" pip install -e ".[all]"
``` ```
## Support ## Support
The best way to get support is to open an issue on this repo or join the EleutherAI discord server](discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for receiving support for our releases. The best way to get support is to open an issue on this repo or join the EleutherAI discord server](discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for receiving support for our releases.
...@@ -86,7 +75,7 @@ python -m lm_eval \ ...@@ -86,7 +75,7 @@ python -m lm_eval \
--batch_size 8 --batch_size 8
``` ```
Models that are loaded via either `transformers.AutoModelForCausalLM` (autoregressive, decoder-only GPT style models) or `transformers.AutoModelForSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface are supported via Support for this model type is currently pending. Models that are loaded via both `transformers.AutoModelForCausalLM` (autoregressive, decoder-only GPT style models) and `transformers.AutoModelForSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface are supporteded.
Batch size selection can be automated by setting the ```--batch_size``` flag to ```auto```. This will perform automatic detection of the largest batch size that will fit on your device. On tasks where there is a large difference between the longest and shortest example, it can be helpful to periodically recompute the largest batch size, to gain a further speedup. To do this, append ```:N``` to above flag to automatically recompute the largest batch size ```N``` times. For example, to recompute the batch size 4 times, the command would be: Batch size selection can be automated by setting the ```--batch_size``` flag to ```auto```. This will perform automatic detection of the largest batch size that will fit on your device. On tasks where there is a large difference between the longest and shortest example, it can be helpful to periodically recompute the largest batch size, to gain a further speedup. To do this, append ```:N``` to above flag to automatically recompute the largest batch size ```N``` times. For example, to recompute the batch size 4 times, the command would be:
...@@ -151,14 +140,14 @@ A full accounting of the supported and planned libraries + APIs can be seen belo ...@@ -151,14 +140,14 @@ A full accounting of the supported and planned libraries + APIs can be seen belo
| API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: | | API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
|-----------------------------|---------------------------------|----------------------------------------------------------------------------------|--------------------------------------|----------------------------------------------------------| |-----------------------------|---------------------------------|----------------------------------------------------------------------------------|--------------------------------------|----------------------------------------------------------|
| OpenAI Completions | :heavy_check_mark: | `openai`, `openai-completions`, `gooseai` | up to `code-davinci-002` | `greedy_until`, `loglikelihood`, `loglikelihood_rolling` | | OpenAI Completions | :heavy_check_mark: | `openai`, `openai-completions`, `gooseai` | up to `code-davinci-002` | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| OpenAI ChatCompletions | :x: Not yet - needs help! | N/A | (link here?) | `greedy_until` (no logprobs) | | OpenAI ChatCompletions | :x: Not yet - needs testing! | N/A | [All ChatCompletions API models](https://platform.openai.com/docs/guides/gpt) | `generate_until` (no logprobs) |
| Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `greedy_until` (no logprobs) | | Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
| GooseAI | :heavy_check_mark: (not separately maintained) | `openai`, `openai-completions`, `gooseai` (same interface as OpenAI Completions) | | `greedy_until`, `loglikelihood`, `loglikelihood_rolling` | | GooseAI | :heavy_check_mark: (not separately maintained) | `openai`, `openai-completions`, `gooseai` (same interface as OpenAI Completions) | | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Textsynth | Needs testing | `textsynth` | ??? | `greedy_until`, `loglikelihood`, `loglikelihood_rolling` | | Textsynth | Needs testing | `textsynth` | ??? | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Cohere | :hourglass: - blocked on Cohere API bug | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `greedy_until`, `loglikelihood`, `loglikelihood_rolling` | | Cohere | :hourglass: - blocked on Cohere API bug | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| GGML | :hourglass: [PR](https://github.com/EleutherAI/lm-evaluation-harness/pull/617) | N/A | ??? | `greedy_until`, `loglikelihood`, `loglikelihood_rolling` | | GGML | :hourglass: [PR](https://github.com/EleutherAI/lm-evaluation-harness/pull/617) | N/A | ??? | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| vLLM | :x: Not yet - needs help! | N/A | All HF models | `greedy_until` (no logprobs) | | vLLM | :x: Not yet - needs help! | N/A | All HF models | `generate_until` (no logprobs) |
| Your inference server here! | ... | ... | ... | ... | | ... | | Your inference server here! | ... | ... | ... | ... | | ... |
It is on our roadmap to create task variants designed to enable models which do not serve logprobs/loglikelihoods to be compared with generation performance of open-source models. It is on our roadmap to create task variants designed to enable models which do not serve logprobs/loglikelihoods to be compared with generation performance of open-source models.
...@@ -227,12 +216,6 @@ python -m lm_eval \ ...@@ -227,12 +216,6 @@ python -m lm_eval \
We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`. We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.
## Implementing new tasks
To implement a new task in the eval harness, see [this guide](./docs/new_task_guide.md).
As a start, we currently only support one prompt per task, which we strive to make the "standard" as defined by the benchmark's authors. If you would like to study how varying prompts causes changes in the evaluation score, we support prompts authored in the [Promptsource Library](https://github.com/bigscience-workshop/promptsource/tree/main) as described further in https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/docs/new_task_guide.md and https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/docs/advanced_task_guide.md and welcome contributions of novel task templates and task variants.
## How to Contribute or Learn More? ## How to Contribute or Learn More?
...@@ -241,35 +224,19 @@ For more information on the library and how everything fits together, check out ...@@ -241,35 +224,19 @@ For more information on the library and how everything fits together, check out
You can also ask for help, or discuss new features with the maintainers in the #lm-thunderdome channel of the EleutherAI discord! If you've used the library and have had a positive (or negative) experience, we'd love to hear from you! You can also ask for help, or discuss new features with the maintainers in the #lm-thunderdome channel of the EleutherAI discord! If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!
### Implementing new tasks
To implement a new task in the eval harness, see [this guide](./docs/new_task_guide.md).
As a start, we currently only support one prompt per task, which we strive to make the "standard" as defined by the benchmark's authors. If you would like to study how varying prompts causes changes in the evaluation score, we support prompts authored in the [Promptsource Library](https://github.com/bigscience-workshop/promptsource/tree/main) as described further in [the task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/docs/new_task_guide.md) and [the advanced task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/docs/advanced_task_guide.md) and welcome contributions of novel task templates and task variants.
## Cite as ## Cite as
``` ```
@software{eval-harness, @misc{eval-harness,
author = {Gao, Leo and author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
Tow, Jonathan and
Abbasi, Baber and
Biderman, Stella and
Black, Sid and
DiPofi, Anthony and
Foster, Charles and
Golding, Laurence and
Hsu, Jeffrey and
Le Noac'h, Alain and
Li, Haonan and
McDonell, Kyle and
Muennighoff, Niklas and
Ociepa, Chris
Phang, Jason and
Reynolds, Laria and
Schoelkopf, Hailey and
Skowron, Aviya and
Sutawika, Lintang and
Tang, Eric and
Thite, Anish and
Wang, Ben and
Wang, Kevin and
Zou, Andy},
title = {A framework for few-shot language model evaluation}, title = {A framework for few-shot language model evaluation},
month = sep, month = sep,
year = 2021, year = 2021,
......
...@@ -57,7 +57,7 @@ import lm_eval ...@@ -57,7 +57,7 @@ import lm_eval
my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code) my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
... ...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.greedy_until()` lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`
results = lm_eval.simple_evaluate( # call simple_evaluate results = lm_eval.simple_evaluate( # call simple_evaluate
model=lm_obj, model=lm_obj,
...@@ -83,7 +83,7 @@ from my_tasks import MyTask1 # suppose you've defined a custom lm_eval.api.Task ...@@ -83,7 +83,7 @@ from my_tasks import MyTask1 # suppose you've defined a custom lm_eval.api.Task
my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code) my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
... ...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.greedy_until()` lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`
......
...@@ -44,28 +44,49 @@ class MyCustomLM(LM): ...@@ -44,28 +44,49 @@ class MyCustomLM(LM):
#... #...
def greedy_until(self, requests: list[Instance]) -> list[str]: def generate_until(self, requests: list[Instance]) -> list[str]:
#... #...
#... #...
``` ```
Where `Instance` is a dataclass defined in [`lm_eval.api.instance`](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/api/instance.py) with property `args` which returns a tuple of (context, continuation). Where `Instance` is a dataclass defined in [`lm_eval.api.instance`](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/api/instance.py) with property `args` of request-dependent type signature described below.
We support We support three types of requests, consisting of different interactions / measurements with an autoregressive LM.
The three types of All three request types take as input `requests` of type `list[Instance]` that have a matching `Instance.request_type` to the method name.
- `generate_until`
- Each request contains `Instance.args : Tuple[str, dict]` containing 1. an input string to the LM and 2. a dictionary of keyword arguments used to control generation parameters.
- Using this input and these generation parameters, text will be sampled from the language model (typically until a maximum output length or specific stopping string sequences--for example, `{"until": ["\n\n", "."], "max_gen_toks": 128}`).
- The generated input+output text from the model will then be returned.
- `loglikelihood`
- Each request contains `Instance.args : Tuple[str, str]` containing 1. an input string to the LM and 2. a target string on which the loglikelihood of the LM producing this target, conditioned on the input, will be returned.
- Each request will have, as result, `(ll, is_greedy): Tuple[float, int]` returned, where `ll` is a floating point number representing the log probability of generating the target string conditioned on the input, and `is_greedy` being either the value `0` or `1`, with it being `1` if and only if the target string *would be generated by greedy sampling from the LM* (that is, if the target string is the *most likely* N-token string to be output by the LM given the input. )
smth smth tokenizer-agnostic - `loglikelihood_rolling`
- Each request contains `Instance.args : Tuple[str]`, which is an input string to the model whose *entire* loglikelihood, conditioned on purely the EOT token, will be calculated.
- This is used to evaluate *perplexity* on a data distribution.
- It should return `(ll,) : Tuple[float]` , a.k.a. solely the *loglikelihood* of producing each piece of text given no starting input.
3 reqtypes
- greedy_until, and the arguments passed to it
- loglikelihood, and args passed to it To allow a model to be evaluated on all types of tasks, you will need to implement these three types of measurements (note that `loglikelihood_rolling` is a special case of `loglikelihood`). For a reference implementation, check out `lm_eval/models/huggingface.py` !
- loglikelihood_rolling, and args passed to it **Tip: be careful of indexing in loglikelihood!**
LMs take in tokens in position `[0 1 2 ... N]` and output a probability distribution for token position `N+1`. We provide a simplified graphic here, excerpted from `huggingface.py`:
```
# how this all works (illustrated on a causal decoder-only setup):
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
```
The final token of the target is not passed into the LM, because we want the LM's predictions *up to but not past* that final target token. For more information, check out https://github.com/EleutherAI/lm-evaluation-harness/issues/942 .
## Registration ## Registration
Congrats on implementing your model! Now it's time to test it out. Congrats on implementing your model! Now it's time to test it out.
...@@ -83,7 +104,9 @@ class MyCustomLM(LM): ...@@ -83,7 +104,9 @@ class MyCustomLM(LM):
Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library! Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library!
## Testing
We also recommend that new model contributions be accompanied by short tests of their 3 core functionalities, at minimum. To see an example of such tests, look at https://github.com/EleutherAI/lm-evaluation-harness/blob/35bdecd379c0cefad6897e67db892f4a6026a128/tests/test_ggml.py .
## Other ## Other
......
...@@ -17,7 +17,7 @@ git checkout -b <task-name> ...@@ -17,7 +17,7 @@ git checkout -b <task-name>
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
As a concrete example, we'll walk through reimplementing the `gsm8k` benchmark (a *generative* task which requires sampling text from a model) and the `sciq` benchmark. (a *discriminative*, or *multiple choice*, task where the model picks the most likely of several fixed answer choices). In this document, we'll walk through the basics of implementing a static benchmark evaluation in two formats: a *generative* task which requires sampling text from a model, such as [`gsm8k`](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/tasks/gsm8k/gsm8k.yaml), and a *discriminative*, or *multiple choice*, task where the model picks the most likely of several fixed answer choices, such as [`sciq`](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/tasks/sciq/sciq.yaml).
## Creating a YAML file ## Creating a YAML file
...@@ -45,6 +45,16 @@ dataset_name: ... # the dataset configuration to use. Leave `null` if your datas ...@@ -45,6 +45,16 @@ dataset_name: ... # the dataset configuration to use. Leave `null` if your datas
dataset_kwargs: null # any extra keyword arguments that should be passed to the dataset constructor, e.g. `data_dir`. dataset_kwargs: null # any extra keyword arguments that should be passed to the dataset constructor, e.g. `data_dir`.
``` ```
------------------------------
**Tip:** To load a local dataset for evaluation, you can specify data files in the `dataset_kwargs` field, such as the following for JSON files:
```
dataset_path: json
dataset_name: null
dataset_kwargs:
data_files: /path/to/my/json
```
-------------------------------
Next, we'd like to tell our task what the dataset's train, validation, and test splits are named, if they exist: Next, we'd like to tell our task what the dataset's train, validation, and test splits are named, if they exist:
```yaml ```yaml
...@@ -116,7 +126,7 @@ doc_to_choice: ['No', 'Yes'] ...@@ -116,7 +126,7 @@ doc_to_choice: ['No', 'Yes']
We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format. We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format.
Take for example `super_glue/boolq`, as input, we'd like to use the features `passage` and `question` and string them together so that for a a sample line `doc`, the model sees something the format of: Take for example the dataset `super_glue/boolq`. As input, we'd like to use the features `passage` and `question` and string them together so that for a a sample line `doc`, the model sees something the format of:
``` ```
doc["passage"] doc["passage"]
Question: doc["question"]? Question: doc["question"]?
...@@ -285,7 +295,7 @@ It's now time to check models' performance on your task! In the evaluation harne ...@@ -285,7 +295,7 @@ It's now time to check models' performance on your task! In the evaluation harne
To enable this, we provide a checklist that should be completed when contributing a new task, to enable accurate book-keeping and to ensure that tasks added to the library are well-tested and, where applicable, precedented. To enable this, we provide a checklist that should be completed when contributing a new task, to enable accurate book-keeping and to ensure that tasks added to the library are well-tested and, where applicable, precedented.
### Task impl. checklist ### Task Validity Checklist
The checklist is the following: The checklist is the following:
......
...@@ -20,19 +20,19 @@ Task naming + registration: ...@@ -20,19 +20,19 @@ Task naming + registration:
Dataset configuration options: Dataset configuration options:
- **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub. - **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
- **dataset_name** (`str`, *optional*, defaults to None) — The name of, what HF calls, a “data instance” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.) - **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a “data instance” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.)
- **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local datafiles such as json or csv. - **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local datafiles such as json or csv.
- **training_split** (`str`, *optional*) — Split in the dataset to use as the training split. - **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
- **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split. - **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
- **test_split** (`str`, *optional*) — Split in the dataset to use as the test split. - **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
- **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) - **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. assert that this not None if num_fewshot > 0.
- **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to the expected format expected by a prompt template. - **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to the expected format expected by a prompt template.
Prompting / in-context formatting options: Prompting / in-context formatting options:
- **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. if defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice. - **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. if defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice.
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model - **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into - **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `greedy_until` tasks. - **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
- **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples. - **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
- **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested. - **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
...@@ -42,7 +42,7 @@ Runtime configuration options: ...@@ -42,7 +42,7 @@ Runtime configuration options:
Scoring details: Scoring details:
- **metric_list** (`str`, *optional*, defaults to None) — A list of metrics to use for evaluation. See docs for expected format. - **metric_list** (`str`, *optional*, defaults to None) — A list of metrics to use for evaluation. See docs for expected format.
- **output_type** (`str`, *optional*, defaults to "greedy_until") — Selects the type of model output for the given task. Options are `greedy_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`. - **output_type** (`str`, *optional*, defaults to "generate_until") — Selects the type of model output for the given task. Options are `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
- **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes. - **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes.
- **repeats** (`int`, *optional*, defaults to 1) — Number of repeated runs through model for each sample. can be used for cases such as self-consistency. - **repeats** (`int`, *optional*, defaults to 1) — Number of repeated runs through model for each sample. can be used for cases such as self-consistency.
- **filter_list** (`Union[str, list]`, *optional*) — List of filters to postprocess model outputs. See below for further detail on the filter API. - **filter_list** (`Union[str, list]`, *optional*) — List of filters to postprocess model outputs. See below for further detail on the filter API.
......
...@@ -5,3 +5,4 @@ maka ...@@ -5,3 +5,4 @@ maka
mor mor
te te
ond ond
extraversion
...@@ -2,11 +2,10 @@ import os ...@@ -2,11 +2,10 @@ import os
import re import re
import json import json
import fnmatch import fnmatch
import jsonlines
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import numpy as np
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger, SPACING from lm_eval.logger import eval_logger, SPACING
...@@ -15,6 +14,15 @@ from lm_eval.tasks import include_path ...@@ -15,6 +14,15 @@ from lm_eval.tasks import include_path
from typing import Union from typing import Union
def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def parse_eval_args() -> argparse.Namespace: def parse_eval_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`") parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
...@@ -97,6 +105,12 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -97,6 +105,12 @@ def parse_eval_args() -> argparse.Namespace:
default=None, default=None,
help="Additional path to include if there are external tasks to include.", help="Additional path to include if there are external tasks to include.",
) )
parser.add_argument(
"--verbosity",
type=str,
default="INFO",
help="Log error when tasks are not registered.",
)
return parser.parse_args() return parser.parse_args()
...@@ -105,6 +119,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -105,6 +119,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
# we allow for args to be passed externally, else we parse them ourselves # we allow for args to be passed externally, else we parse them ourselves
args = parse_eval_args() args = parse_eval_args()
eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
if args.limit: if args.limit:
...@@ -112,7 +127,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -112,7 +127,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
" --limit SHOULD ONLY BE USED FOR TESTING." " --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
) )
if args.include_path is not None: if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}") eval_logger.info(f"Including path: {args.include_path}")
include_path(args.include_path) include_path(args.include_path)
...@@ -188,7 +202,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -188,7 +202,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None: if results is not None:
if args.log_samples: if args.log_samples:
samples = results.pop("samples") samples = results.pop("samples")
dumped = json.dumps(results, indent=2, default=lambda o: str(o)) dumped = json.dumps(results, indent=2, default=_handle_non_serializable)
if args.show_config: if args.show_config:
print(dumped) print(dumped)
...@@ -203,9 +217,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -203,9 +217,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
re.sub("/|=", "__", args.model_args), task_name re.sub("/|=", "__", args.model_args), task_name
) )
filename = path.joinpath(f"{output_name}.jsonl") filename = path.joinpath(f"{output_name}.jsonl")
samples_dumped = json.dumps(
with jsonlines.open(filename, "w") as f: samples[task_name], indent=2, default=_handle_non_serializable
f.write_all(samples[task_name]) )
filename.open("w").write(samples_dumped)
print( print(
f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
......
...@@ -4,7 +4,7 @@ from typing import Literal, Tuple ...@@ -4,7 +4,7 @@ from typing import Literal, Tuple
@dataclass @dataclass
class Instance: class Instance:
request_type: Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"] request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
doc: dict doc: dict
arguments: tuple arguments: tuple
idx: int idx: int
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import sacrebleu import sacrebleu
import sklearn.metrics import sklearn.metrics
import random import random
import evaluate
from lm_eval.api.registry import register_metric, register_aggregation from lm_eval.api.registry import register_metric, register_aggregation
...@@ -105,6 +106,25 @@ def ter(items): ...@@ -105,6 +106,25 @@ def ter(items):
return sacrebleu.corpus_ter(preds, refs).score return sacrebleu.corpus_ter(preds, refs).score
@register_aggregation("brier_score")
def brier_score(items): # This is a passthrough function
gold, predictions = list(zip(*items))
gold = list(gold)
gold_one_hot = np.eye(np.max(gold) + 1)[gold]
predictions = list(zip(*items))[1]
return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
@register_metric(
metric="brier_score",
higher_is_better=False,
output_type=["multiple_choice"],
aggregation="brier_score",
)
def brier_score_fn(items): # This is a passthrough function
return items
@register_metric( @register_metric(
metric="acc", metric="acc",
higher_is_better=True, higher_is_better=True,
...@@ -135,6 +155,19 @@ def acc_mutual_info_fn(items): # This is a passthrough function ...@@ -135,6 +155,19 @@ def acc_mutual_info_fn(items): # This is a passthrough function
return items return items
exact_match = evaluate.load("exact_match")
@register_metric(
metric="exact_match",
higher_is_better=True,
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(**kwargs):
return exact_match.compute(**kwargs)
@register_metric( @register_metric(
metric="perplexity", metric="perplexity",
higher_is_better=False, higher_is_better=False,
...@@ -212,7 +245,7 @@ def f1_fn(items): # This is a passthrough function ...@@ -212,7 +245,7 @@ def f1_fn(items): # This is a passthrough function
@register_metric( @register_metric(
metric="bleu", metric="bleu",
higher_is_better=True, higher_is_better=True,
output_type="greedy_until", output_type="generate_until",
aggregation="bleu", aggregation="bleu",
) )
def bleu_fn(items): # This is a passthrough function def bleu_fn(items): # This is a passthrough function
...@@ -222,7 +255,7 @@ def bleu_fn(items): # This is a passthrough function ...@@ -222,7 +255,7 @@ def bleu_fn(items): # This is a passthrough function
@register_metric( @register_metric(
metric="chrf", metric="chrf",
higher_is_better=True, higher_is_better=True,
output_type="greedy_until", output_type="generate_until",
aggregation="chrf", aggregation="chrf",
) )
def chrf_fn(items): # This is a passthrough function def chrf_fn(items): # This is a passthrough function
...@@ -232,7 +265,7 @@ def chrf_fn(items): # This is a passthrough function ...@@ -232,7 +265,7 @@ def chrf_fn(items): # This is a passthrough function
@register_metric( @register_metric(
metric="ter", metric="ter",
higher_is_better=True, higher_is_better=True,
output_type="greedy_until", output_type="generate_until",
aggregation="ter", aggregation="ter",
) )
def ter_fn(items): # This is a passthrough function def ter_fn(items): # This is a passthrough function
......
...@@ -96,7 +96,7 @@ class LM(abc.ABC): ...@@ -96,7 +96,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length # TODO: Add an optional max length
@abc.abstractmethod @abc.abstractmethod
def greedy_until(self, requests) -> List[str]: def generate_until(self, requests) -> List[str]:
"""Generate greedily until a stopping sequence """Generate greedily until a stopping sequence
:param requests: list[Instance] :param requests: list[Instance]
...@@ -211,12 +211,12 @@ class CachingLM: ...@@ -211,12 +211,12 @@ class CachingLM:
) )
for req in tqdm(requests): for req in tqdm(requests):
hsh = hash_args(attr, req.args) hsh = hash_args(attr, req.args)
if attr == "greedy_until" and req.args[1].get("do_sample", False): if attr == "generate_until" and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache # when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1). # (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned: if not warned:
eval_logger.warning( eval_logger.warning(
f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests." f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
) )
warned = True warned = True
res.append(None) res.append(None)
......
...@@ -68,10 +68,10 @@ def register_group(name): ...@@ -68,10 +68,10 @@ def register_group(name):
return decorate return decorate
AGGREGATION_REGISTRY = {}
DEFAULT_AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
OUTPUT_TYPE_REGISTRY = {} OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {} HIGHER_IS_BETTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = { DEFAULT_METRIC_REGISTRY = {
...@@ -81,7 +81,7 @@ DEFAULT_METRIC_REGISTRY = { ...@@ -81,7 +81,7 @@ DEFAULT_METRIC_REGISTRY = {
], ],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"], "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"], "multiple_choice": ["acc", "acc_norm"],
"greedy_until": ["exact_match"], "generate_until": ["exact_match"],
} }
...@@ -95,8 +95,7 @@ def register_metric(**args): ...@@ -95,8 +95,7 @@ def register_metric(**args):
for key, registry in [ for key, registry in [
("metric", METRIC_REGISTRY), ("metric", METRIC_REGISTRY),
("higher_is_better", HIGHER_IS_BETTER_REGISTRY), ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
# ("output_type", OUTPUT_TYPE_REGISTRY), ("aggregation", METRIC_AGGREGATION_REGISTRY),
("aggregation", DEFAULT_AGGREGATION_REGISTRY),
]: ]:
if key in args: if key in args:
...@@ -158,12 +157,13 @@ def get_aggregation(name): ...@@ -158,12 +157,13 @@ def get_aggregation(name):
) )
def get_default_aggregation(metric_name): def get_metric_aggregation(name):
try: try:
return DEFAULT_AGGREGATION_REGISTRY[metric_name] return METRIC_AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning( eval_logger.warning(
f"No default aggregation metric for metric '{metric_name}'!" "{} metric is not assigned a default aggregation!".format(name),
) )
...@@ -171,7 +171,6 @@ def is_higher_better(metric_name): ...@@ -171,7 +171,6 @@ def is_higher_better(metric_name):
try: try:
return HIGHER_IS_BETTER_REGISTRY[metric_name] return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError: except KeyError:
raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")
eval_logger.warning( eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!" f"higher_is_better not specified for metric '{metric_name}'!"
) )
...@@ -33,7 +33,7 @@ from lm_eval.api.metrics import ( ...@@ -33,7 +33,7 @@ from lm_eval.api.metrics import (
from lm_eval.api.registry import ( from lm_eval.api.registry import (
get_metric, get_metric,
get_aggregation, get_aggregation,
get_default_aggregation, get_metric_aggregation,
is_higher_better, is_higher_better,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY, OUTPUT_TYPE_REGISTRY,
...@@ -44,7 +44,7 @@ ALL_OUTPUT_TYPES = [ ...@@ -44,7 +44,7 @@ ALL_OUTPUT_TYPES = [
"loglikelihood", "loglikelihood",
"multiple_choice", "multiple_choice",
"loglikelihood_rolling", "loglikelihood_rolling",
"greedy_until", "generate_until",
] ]
...@@ -52,7 +52,9 @@ ALL_OUTPUT_TYPES = [ ...@@ -52,7 +52,9 @@ ALL_OUTPUT_TYPES = [
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
task: str = None task: str = None
task_alias: str = None
group: Union[str, list] = None group: Union[str, list] = None
group_alias: Union[str, list] = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
...@@ -69,7 +71,6 @@ class TaskConfig(dict): ...@@ -69,7 +71,6 @@ class TaskConfig(dict):
doc_to_text: Union[Callable, str] = None doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None doc_to_target: Union[Callable, str] = None
doc_to_choice: Union[Callable, str, dict, list] = None doc_to_choice: Union[Callable, str, dict, list] = None
gold_alias: Union[Callable, str] = None
process_results: Union[Callable, str] = None process_results: Union[Callable, str] = None
use_prompt: str = None use_prompt: str = None
description: str = "" description: str = ""
...@@ -80,7 +81,7 @@ class TaskConfig(dict): ...@@ -80,7 +81,7 @@ class TaskConfig(dict):
num_fewshot: int = 0 num_fewshot: int = 0
# scoring options # scoring options
metric_list: list = None metric_list: list = None
output_type: str = "greedy_until" output_type: str = "generate_until"
generation_kwargs: dict = None generation_kwargs: dict = None
repeats: int = 1 repeats: int = 1
filter_list: Union[str, list] = None filter_list: Union[str, list] = None
...@@ -97,11 +98,11 @@ class TaskConfig(dict): ...@@ -97,11 +98,11 @@ class TaskConfig(dict):
self.dataset_path = inspect.getfile(import_module(self.dataset_path)) self.dataset_path = inspect.getfile(import_module(self.dataset_path))
if self.generation_kwargs is not None: if self.generation_kwargs is not None:
if self.output_type != "greedy_until": if self.output_type != "generate_until":
eval_logger.warning( eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: greedy_until`!" f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
) )
assert self.output_type != "greedy_until" assert self.output_type != "generate_until"
if "temperature" in self.generation_kwargs: if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float( self.generation_kwargs["temperature"] = float(
...@@ -111,7 +112,7 @@ class TaskConfig(dict): ...@@ -111,7 +112,7 @@ class TaskConfig(dict):
if "until" not in self.generation_kwargs: if "until" not in self.generation_kwargs:
self.generation_kwargs["until"] = [self.fewshot_delimiter] self.generation_kwargs["until"] = [self.fewshot_delimiter]
else: else:
if self.output_type == "greedy_until": if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise # ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = { self.generation_kwargs = {
"until": None "until": None
...@@ -538,12 +539,14 @@ class ConfigurableTask(Task): ...@@ -538,12 +539,14 @@ class ConfigurableTask(Task):
self._aggregation_list = {} self._aggregation_list = {}
self._higher_is_better = {} self._higher_is_better = {}
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
if self.config.metric_list is None: if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ? # TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list: for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name) self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_default_aggregation( self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name metric_name
) )
self._higher_is_better[metric_name] = is_higher_better(metric_name) self._higher_is_better[metric_name] = is_higher_better(metric_name)
...@@ -586,7 +589,7 @@ class ConfigurableTask(Task): ...@@ -586,7 +589,7 @@ class ConfigurableTask(Task):
] ]
else: else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()} INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_default_aggregation(metric_name) metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning( eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. " f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default " f"using default "
...@@ -687,7 +690,10 @@ class ConfigurableTask(Task): ...@@ -687,7 +690,10 @@ class ConfigurableTask(Task):
for choice in check_choices: for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False choice_has_whitespace = True if choice[0].isspace() else False
delimiter_has_whitespace = ( delimiter_has_whitespace = (
True if self.config.target_delimiter[-1].isspace() else False True
if self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
else False
) )
if delimiter_has_whitespace and choice_has_whitespace: if delimiter_has_whitespace and choice_has_whitespace:
...@@ -696,7 +702,7 @@ class ConfigurableTask(Task): ...@@ -696,7 +702,7 @@ class ConfigurableTask(Task):
) )
elif (not delimiter_has_whitespace) and (not choice_has_whitespace): elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.warning( eval_logger.warning(
f'Both target_delimiter and target choice: "{choice}" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
) )
def download(self, dataset_kwargs=None) -> None: def download(self, dataset_kwargs=None) -> None:
...@@ -888,26 +894,6 @@ class ConfigurableTask(Task): ...@@ -888,26 +894,6 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def gold_alias(self, doc):
# returns a version of the gold target answer to a document,
# which should be passed into metric for scoring as the ground truth.
# in multiple_choice tasks, this should be castable to an int corresponding to the index
# within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}.
if self.config.gold_alias is not None:
doc_to_target = self.config.gold_alias
else:
return self.doc_to_target(doc)
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target):
return doc_to_target(doc)
elif hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
raise TypeError
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]: ) -> Union[List[Instance], Instance]:
...@@ -958,7 +944,7 @@ class ConfigurableTask(Task): ...@@ -958,7 +944,7 @@ class ConfigurableTask(Task):
) )
return request_list return request_list
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, self.config.generation_kwargs) arguments = (ctx, self.config.generation_kwargs)
return Instance( return Instance(
...@@ -1055,12 +1041,21 @@ class ConfigurableTask(Task): ...@@ -1055,12 +1041,21 @@ class ConfigurableTask(Task):
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
exact_match = int(is_greedy[gold]) if gold != -100 else 0 exact_match = int(is_greedy[gold]) if gold != -100 else 0
prob_norm = utils.softmax(lls)
# TODO use keyword arguments to the metric?
# gold, pred, norm stuff, the original lls,
result_dict = { result_dict = {
**({"acc": acc} if "acc" in use_metric else {}), **({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}), **({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}), **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}), **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**(
{"brier_score": (gold, prob_norm)}
if "brier_score" in use_metric
else {}
),
} }
if "acc_mutual_info" in use_metric: if "acc_mutual_info" in use_metric:
...@@ -1070,7 +1065,7 @@ class ConfigurableTask(Task): ...@@ -1070,7 +1065,7 @@ class ConfigurableTask(Task):
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
result = results[0] result = results[0]
if self.config.doc_to_choice is not None: if self.config.doc_to_choice is not None:
...@@ -1134,7 +1129,7 @@ class ConfigurableTask(Task): ...@@ -1134,7 +1129,7 @@ class ConfigurableTask(Task):
else: else:
raise ValueError( raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until' or 'multiple_choice'", "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
) )
return result_dict return result_dict
......
import os
import yaml
from lm_eval import utils
from lm_eval.tasks import register_configurable_task, check_prompt_config
from lm_eval.logger import eval_logger
from lm_eval.api.registry import (
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
)
def include_benchmarks(task_dir: str) -> None:
for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == [] or "__pycache__" in subdirs) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
try:
benchmark_path = os.path.join(root, f)
with open(benchmark_path, "rb") as file:
yaml_config = yaml.full_load(file)
if "prompts" in yaml_config:
continue # Skip it
assert "group" in yaml_config
group = yaml_config["group"]
all_task_list = yaml_config["task"]
config_list = [
task for task in all_task_list if type(task) != str
]
task_list = [
task for task in all_task_list if type(task) == str
]
for task_config in config_list:
yaml_dir = os.path.dirname(benchmark_path)
task_config = utils.load_yaml_config(
yaml_config=task_config, yaml_dir=yaml_dir
)
if "use_prompt" in task_config:
if "yaml" in task_config["use_prompt"]:
task_config["use_prompt"] = os.path.join(
root, task_config["use_prompt"]
)
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
)
for config in var_configs:
register_configurable_task(config)
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if task in TASK_REGISTRY:
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
except Exception as error:
eval_logger.warning(
"Failed to load benchmark in\n"
f" {benchmark_path}\n"
" Benchmark will not be added to registry\n"
f" Error: {error}"
)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_benchmarks(task_dir)
...@@ -2,7 +2,6 @@ import random ...@@ -2,7 +2,6 @@ import random
import itertools import itertools
import json import json
import collections import collections
import logging
import sys import sys
import torch import torch
...@@ -25,10 +24,6 @@ from lm_eval.utils import ( ...@@ -25,10 +24,6 @@ from lm_eval.utils import (
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))
@positional_deprecated @positional_deprecated
def simple_evaluate( def simple_evaluate(
...@@ -221,14 +216,15 @@ def evaluate( ...@@ -221,14 +216,15 @@ def evaluate(
task_hierarchy = collections.defaultdict(list) task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups # store the ordering of tasks and groups
task_order = collections.defaultdict(int) task_order = collections.defaultdict(int)
# store the aggregation for aggregating across tasks in the same group task_group_alias = collections.defaultdict(dict)
sample_agg_fn = collections.defaultdict(dict)
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
if type(task) == tuple: if type(task) == tuple:
group_name, task = task group_name, task = task
task_hierarchy[group_name].append(task_name) task_hierarchy[group_name].append(task_name)
versions[group_name] = "N/A"
else: else:
task_hierarchy[task_name] = [] task_hierarchy[task_name] = []
...@@ -238,6 +234,14 @@ def evaluate( ...@@ -238,6 +234,14 @@ def evaluate(
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config()) configs[task_name] = dict(task.dump_config())
if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"]
if ("group_alias" in configs[task_name]) and (
group_name not in task_group_alias
):
task_group_alias[group_name] = configs[task_name]["group_alias"]
if limit is not None: if limit is not None:
if task.has_test_docs(): if task.has_test_docs():
task_docs = task.test_docs() task_docs = task.test_docs()
...@@ -449,23 +453,8 @@ def evaluate( ...@@ -449,23 +453,8 @@ def evaluate(
group_name = None group_name = None
agg_fn = task.aggregation()[metric] agg_fn = task.aggregation()[metric]
task_score = agg_fn(items) results[task_name][metric_key] = agg_fn(items)
results[task_name]["samples"] = len(items)
if group_name is not None:
sample_metric_key = metric + "(sample agg)," + key
for grouping in task_to_group[task_name]:
if metric_key in results[grouping]:
results[grouping][metric_key].append(task_score)
else:
results[grouping][metric_key] = [task_score]
if sample_metric_key in results[grouping]:
results[grouping][sample_metric_key] += items
else:
results[grouping][sample_metric_key] = items.copy()
sample_agg_fn[grouping][sample_metric_key] = agg_fn
results[task_name][metric_key] = task_score
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this # so we run them less iterations. still looking for a cleaner way to do this
...@@ -481,33 +470,139 @@ def evaluate( ...@@ -481,33 +470,139 @@ def evaluate(
results[task_name][metric + "_stderr" + "," + key] = stderr(items) results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if bool(results): if bool(results):
for task_or_group in results.keys():
for metric in results[task_or_group].keys():
if type(results[task_or_group][metric]) == list:
if "(sample agg)" in metric:
results[task_or_group][metric] = sample_agg_fn[
task_or_group
][metric](results[task_or_group][metric])
else:
results[task_or_group][metric] = np.average(
results[task_or_group][metric]
)
versions[task_or_group] = "N/A"
for task_name, task in task_dict.items(): for group, task_list in reversed(task_hierarchy.items()):
if type(task) == tuple:
group_name, task = task if task_list == []:
total_size = results[group]["samples"]
else:
total_size = 0
for task in task_list:
metrics = results[task]
current_size = metrics.pop("samples")
# TODO: There should be a way for users
# to toggle between weighted and
# unweighted averaging
# For unweighted averaging, use:
# current_size = 1
all_stderr = []
for metric in [
key for key in metrics.keys() if "_stderr" not in key
]:
stderr = "_stderr,".join(metric.split(","))
stderr_score = results[task][stderr]
var_score = stderr_score**2
metric_score = results[task][metric]
all_stderr.append(stderr)
if metric in results[group]:
results[group][metric] = (
results[group][metric] * total_size
+ metric_score * current_size
) / (total_size + current_size)
# $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
results[group][stderr] = (
(total_size - 1) * results[group][stderr]
+ (current_size - 1) * var_score
) / (
total_size + current_size - 1
) + total_size * current_size / (
(total_size + current_size)
* (total_size + current_size - 1)
) * (
results[group][metric] - metric_score
) ** 2
else:
results[group][metric] = metric_score
results[group][stderr] = var_score
total_size += current_size
for stderr in all_stderr:
results[group][stderr] = np.sqrt(results[group][stderr])
results[group]["samples"] = total_size
def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
for group_name, task_list in task_hierarchy.items():
order = task_order[group_name] order = task_order[group_name]
tabbed_name = "-" * order + group_name results_agg[group_name] = results[group_name].copy()
results_agg[tabbed_name] = results[group_name] results_agg[group_name]["tab"] = order
versions[tabbed_name] = versions[group_name]
if order == 0: if (order < max(task_order.values())) and (len(task_list) > 0):
groups_agg[group_name] = results[group_name] groups_agg[group_name] = results[group_name].copy()
groups_agg[group_name]["tab"] = order
order = task_order[task_name]
tabbed_name = "-" * order + task_name if task_list != []:
results_agg[tabbed_name] = results[task_name] for task in sorted(task_list):
versions[tabbed_name] = versions[task_name] if task in task_hierarchy:
_task_hierarchy = {task: task_hierarchy[task]}
else:
_task_hierarchy = {task: []}
_results_agg, _groups_agg, task_version = print_tasks(
_task_hierarchy, task_order, task_version, task_group_alias
)
results_agg = {**results_agg, **_results_agg}
groups_agg = {**groups_agg, **_groups_agg}
return results_agg, groups_agg, task_version
results_agg, groups_agg, versions = print_tasks(
task_hierarchy, task_order, versions, task_group_alias
)
_results_agg = collections.defaultdict(dict)
_versions = collections.defaultdict(dict)
for task in results_agg:
task_results = results_agg[task]
if "samples" in task_results:
task_results.pop("samples")
tab_string = ""
if "tab" in task_results:
tab = task_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if task in task_group_alias:
task_alias = task_group_alias[task]
_results_agg[tab_string + task_alias] = task_results
_versions[tab_string + task_alias] = versions[task]
else:
_results_agg[tab_string + task] = task_results
_versions[tab_string + task] = versions[task]
results_agg = _results_agg
versions = _versions
_groups_agg = collections.defaultdict(dict)
for group in groups_agg:
group_results = groups_agg[group]
if "samples" in group_results:
group_results.pop("samples")
tab_string = ""
if "tab" in group_results:
tab = group_results.pop("tab")
tab_string = " " * tab + "- " if tab > 0 else ""
if group in task_group_alias:
group_alias = task_group_alias[group]
_groups_agg[tab_string + group_alias] = group_results
else:
_groups_agg[tab_string + group] = group_results
groups_agg = _groups_agg
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
......
...@@ -138,7 +138,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -138,7 +138,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False): def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def greedy_until(self, requests) -> List[str]: def generate_until(self, requests) -> List[str]:
if not requests: if not requests:
return [] return []
...@@ -164,7 +164,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -164,7 +164,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
) )
res.append(response) res.append(response)
self.cache_hook.add_partial("greedy_until", request, response) self.cache_hook.add_partial("generate_until", request, response)
except anthropic.APIConnectionError as e: # type: ignore # noqa: F821 except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
eval_logger.critical(f"Server unreachable: {e.__cause__}") eval_logger.critical(f"Server unreachable: {e.__cause__}")
break break
...@@ -179,7 +179,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -179,7 +179,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
raise NotImplementedError() raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until # Isn't used because we override generate_until
raise NotImplementedError() raise NotImplementedError()
def loglikelihood(self, requests): def loglikelihood(self, requests):
......
...@@ -20,7 +20,7 @@ class DummyLM(LM): ...@@ -20,7 +20,7 @@ class DummyLM(LM):
return res return res
def greedy_until(self, requests): def generate_until(self, requests):
res = [] res = []
for ctx, _ in requests: for ctx, _ in requests:
......
...@@ -621,6 +621,23 @@ class HFLM(LM): ...@@ -621,6 +621,23 @@ class HFLM(LM):
return loglikelihoods return loglikelihoods
def _batch_scheduler(self, pos, n_reordered_requests):
sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
if sched in self.batch_sizes:
return self.batch_sizes[sched]
if (len(self.batch_sizes) > 1) and (
self.batch_sizes[sched - 1] == self.max_batch_size
):
# if previous batch size is already maximal, skip recomputation
self.batch_sizes[sched] = self.max_batch_size
return self.batch_sizes[sched]
print(
f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
)
self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
return self.batch_sizes[sched]
def _loglikelihood_tokens( def _loglikelihood_tokens(
self, requests, disable_tqdm: bool = False, override_bs=None self, requests, disable_tqdm: bool = False, override_bs=None
): ):
...@@ -644,38 +661,22 @@ class HFLM(LM): ...@@ -644,38 +661,22 @@ class HFLM(LM):
# automatic (variable) batch size detection for vectorization # automatic (variable) batch size detection for vectorization
# pull longest context sample from request # pull longest context sample from request
def _batch_scheduler(pos): chunks = utils.chunks(
sched = pos // int(n_reordered_requests / self.batch_schedule) re_ord.get_reordered(),
if sched in self.batch_sizes:
return self.batch_sizes[sched]
if (len(self.batch_sizes) > 1) and (
self.batch_sizes[sched - 1] == self.max_batch_size
):
# if previous batch size is already maximal, skip recomputation
self.batch_sizes[sched] = self.max_batch_size
return self.batch_sizes[sched]
print(
f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
)
self.batch_sizes[sched] = self._detect_batch_size(
re_ord.get_reordered(), pos
)
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
return self.batch_sizes[sched]
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
n=self.batch_size n=self.batch_size
if self.batch_size != "auto" if self.batch_size != "auto"
else override_bs else override_bs
if override_bs is not None if override_bs is not None
else 0, else 0,
fn=_batch_scheduler fn=self._batch_scheduler
if self.batch_size == "auto" if self.batch_size == "auto"
and n_reordered_requests > 0 and n_reordered_requests > 0
and not override_bs and not override_bs
else None, else None,
): )
pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
for chunk in chunks:
inps = [] inps = []
cont_toks_list = [] cont_toks_list = []
inplens = [] inplens = []
...@@ -812,10 +813,13 @@ class HFLM(LM): ...@@ -812,10 +813,13 @@ class HFLM(LM):
res.append(answer) res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
pbar.update(1)
pbar.close()
return re_ord.get_original(res) return re_ord.get_original(res)
def greedy_until(self, requests): def generate_until(self, requests):
res = defaultdict(list) res = defaultdict(list)
re_ords = {} re_ords = {}
...@@ -838,13 +842,26 @@ class HFLM(LM): ...@@ -838,13 +842,26 @@ class HFLM(LM):
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate) re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(self.rank != 0))
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")
batch_size = self._detect_batch_size()
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items(): for key, re_ord in re_ords.items():
for chunk in utils.chunks( chunks = utils.chunks(
re_ord.get_reordered(), re_ord.get_reordered(),
self.batch_size, n=self.batch_size
): if self.batch_size != "auto"
else adaptive_batch_size
if adaptive_batch_size is not None
else 0,
fn=self._batch_scheduler
if self.batch_size == "auto" and not adaptive_batch_size
else None,
)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk) contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same # we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it. # this is safe to assume because the `grouper` object ensures it.
...@@ -920,7 +937,7 @@ class HFLM(LM): ...@@ -920,7 +937,7 @@ class HFLM(LM):
res[key].append(s) res[key].append(s)
self.cache_hook.add_partial( self.cache_hook.add_partial(
"greedy_until", (context, gen_kwargs), s "generate_until", (context, gen_kwargs), s
) )
pbar.update(1) pbar.update(1)
# reorder this group of results back to original unsorted form # reorder this group of results back to original unsorted form
......
...@@ -203,7 +203,7 @@ class OpenaiCompletionsLM(LM): ...@@ -203,7 +203,7 @@ class OpenaiCompletionsLM(LM):
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res) return re_ord.get_original(res)
def greedy_until(self, requests) -> List[str]: def generate_until(self, requests) -> List[str]:
if not requests: if not requests:
return [] return []
res = [] res = []
...@@ -260,7 +260,7 @@ class OpenaiCompletionsLM(LM): ...@@ -260,7 +260,7 @@ class OpenaiCompletionsLM(LM):
# partial caching # partial caching
self.cache_hook.add_partial( self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s "generate_until", (context, {"until": until_}), s
) )
res.append(s) res.append(s)
...@@ -271,7 +271,7 @@ class OpenaiCompletionsLM(LM): ...@@ -271,7 +271,7 @@ class OpenaiCompletionsLM(LM):
raise NotImplementedError() raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id): def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until # Isn't used because we override generate_until
raise NotImplementedError() raise NotImplementedError()
def loglikelihood_rolling(self, requests) -> List[float]: def loglikelihood_rolling(self, requests) -> List[float]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment