Unverified commit 5a766ac5 authored by farzanehnakhaee70, committed by GitHub

Merge branch 'big-refactor' into big-refactor

parents a0c1dbbd 01cfb2ff
@@ -7,13 +7,38 @@
This project provides a unified framework to test generative language models on a large number of different evaluation tasks.

**Features:**

- 200+ tasks implemented. See the [task-table](./docs/task_table.md) for a complete list.
- Support for the Hugging Face `transformers` library, GPT-NeoX, Megatron-DeepSpeed, and the OpenAI API, with a flexible, tokenization-agnostic interface.
- Support for evaluation on adapters (e.g. LoRA) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft).
- Task versioning to ensure reproducibility.
**Evaluation Overview**
`Task` and `Prompt` classes contain information that, when combined, produces the input to the language model. The language model is then queried to obtain an output. One or more `Filters` can then be applied to perform arbitrary operations on the model's raw output, such as selecting the final answer (for chain of thought) or calling an external API. This final output is then evaluated using a `Metric` to obtain the final result.
```mermaid
graph LR;
classDef empty width:0px,height:0px;
T[Task]
I[Input]
F[Filter]
M[Model]
O[Output]:::empty
P[Prompt]
Me[Metric]
R[Result]
T --- I:::empty
P --- I
I --> M
M --> O
O --> F
Me --> R:::empty
F --> R
```
## Install

To install `lm-eval` from the GitHub repository main branch, run:
...
Tracking progress on revamping documentation pages for the refactor of LM-Evaluation-Harness.
## Desired Pages
* [ ] YAML explainer
* [ ] Explainer on filters + advanced features
* [ ] Walkthrough start-to-finish of adding a new task to codebase
* [ ] Explaining registries + decorators
* [ ] model_guide.md for adding new model API
* [ ] guide to writing an adapter to new advanced codebase (e.g. NeoX)
* [ ] Parallelism guide (?)
# Advanced Task Configuration
The `lm-evaluation-harness` is meant to be an extensible and flexible framework within which many different evaluation tasks can be defined. All tasks in the new version of the harness are built around a YAML configuration file format.
These YAML configuration files, along with the current codebase commit hash, are intended to be shareable: providing the YAML config enables another researcher to precisely replicate an evaluation setup, in the case that the prompt or setup differs from standard `lm-eval` task implementations.
While adding a standard evaluation task on a new dataset can occasionally be as simple as swapping out a Hugging Face dataset path in an existing file, more specialized evaluation setups need extra configuration. Here we'll provide a crash course on the more advanced logic available to users in YAML form.
If your intended task relies on features beyond what are described in this guide, we'd love to hear about it! Feel free to open an issue describing the scenario on GitHub, create a PR to the project with a proposed implementation, or ask in the `#lm-thunderdome` channel on the EleutherAI Discord.
## Configurations
Tasks are configured via the `TaskConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
### Parameters
- **task** (`str`, defaults to None) — name of the task.
- **group** (`str`, *optional*) — name of the task group(s) a task belongs to. Enables one to run all tasks with a specified tag or group name at once.
- **reference** (`str`, *optional*) —
- **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
- **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a "data instance" or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first two arguments to it.)
- **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local data files such as JSON or CSV.
- **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
- **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
- **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
- **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot examples from. Should not be None if `num_fewshot` > 0.
- **template_aliases** (`str`, *optional*) —
- **aliases**: (`Union[str, list]`, *optional*) —
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model
- **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input.
- **batch_size** (`int`, *optional*, defaults to 1) — Batch size.
- **repeats** (`int`, *optional*, defaults to 1) — Number of repeated runs for each sample. Can be used for cases such as self-consistency.
- **metric_list** (`str`, *optional*, defaults to None) — A list of metrics to use for evaluation. See the docs for the expected format.
- **gold_alias** (`str`, *optional*, defaults to None) — If provided, used to generate the reference answer that is scored against. Useful when `doc_to_target` is the "target string" appended to each few-shot exemplar's input, so `doc_to_target` is used for few-shot examples, while the `gold` value passed to the metric function comes from `gold_alias`.
- **output_type** (`str`, *optional*, defaults to "greedy_until") — Selects the type of model output for the given task. Options are `greedy_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
- **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes.
- **delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
- **filter_list** (`Union[str, list]`, *optional*) — List of filters to postprocess model outputs. See below for further detail on the filter API.
- **should_decontaminate** (`bool`, *optional*, defaults to False) -
- **doc_to_decontamination_query** (`str`, *optional*) —
- **use_prompt** (`str`, *optional*) — Name of a prompt in Promptsource to use. If defined, this will overwrite `doc_to_text` and `doc_to_target`.
- **metadata** (`str`, *optional*) — An optional field where arbitrary metadata can be passed.
## Filters
What are filters, and what is their place in the pipeline?
A key component of the `lm-evaluation-harness` library is the `Filter` object. In a typical evaluation run of the harness, we take the formatted inputs and run them through our LM, with the appropriate output type (greedy or free-form generation, or loglikelihood-based comparative scoring).
After getting scores or output text from our LM on each `Instance` or document in the dataset, we then need to feed these responses into a metric or scoring function to return scores to a user.
However, certain tasks may require more complex behavior than directly turning over model outputs to a metric function. For example, we may want to post-process our output text by truncating it or extracting a model's answer, we may want to ensemble over multiple "takes" on the same document, et cetera.
**Detailed Aside**:
We do such post-processing by operating on *responses*, which are stored after running an LM on an `Instance` from the task in `Instance.resps`.
`resps` is a `List[str]` for each instance, and we pass a `List[List[<expected return type from model>]]` to our filters that is a list of `[instance.resps for instance in instances]`.
Our filters, after completing a pipeline, must return a `List[<expected return type from model>]`; we then unpack this list and store each element in `Instance.filtered_resps` for the corresponding instance. Thus, a filter pipeline takes as input a list of the model's returns for each doc, and must return a single model return *not wrapped in a list* for each doc.
**End Aside**
A full list of supported filter operations can be found in `lm_eval/filters/__init__.py`. Contributions of new filter types are welcome!
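As a rough illustration, a custom filter might look like the following. This is a minimal sketch assuming the `Filter` base class is importable from `lm_eval.api.filter` (as in the code changes further down) and exposes the `apply` interface described in the aside above; the class itself is hypothetical and not part of the library.
```python
# Minimal sketch of a custom filter (hypothetical; not shipped with lm-eval).
# `apply` receives one list of model responses per document and must return
# a same-shaped structure with those responses transformed.
from lm_eval.api.filter import Filter


class LowercaseFilter(Filter):
    def apply(self, resps):
        # resps: List[List[str]] -- one inner list of raw responses per document
        return [[resp.lower() for resp in doc_resps] for doc_resps in resps]
```
A filter like this could sit in the middle of a pipeline, with a later step such as `take_first` producing the final unwrapped response per document.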
### Multiple Filter Pipelines
Tasks need not be limited to a single filter pipeline. We enable users to run multiple, distinct, filter pipelines on *the same model outputs* generated in one run on a task.
As a case study, let's look at an implementation of the GSM8k math word problem benchmark in `lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml`. Here, we emulate the setup used by [Self-Consistency Improves Chain of Thought Prompting](https://arxiv.org/abs/2203.11171), in which evaluation is performed by generating N chain-of-thought outputs from a model via temperature-based sampling, selecting the answer given at the end of each chain of thought, and then majority voting across all those numeric answers.
Within our YAML file:
```yaml
...
repeats: 64
filter_list:
  - name: "score-first"
    filter:
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
      - function: "take_first"
  - name: "maj@64"
    filter:
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
      - function: "majority_vote"
      - function: "take_first"
  - name: "maj@8"
    filter:
      - function: "take_first_k"
        k: 8
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
      - function: "majority_vote"
      - function: "take_first"
```
We are able to provide multiple different filter pipelines, each with their own name and list of filters to apply in sequence.
Our first filter pipeline, "score-first", implements:
- applying a regex to the model generations (extracting the number within the phrase "The answer is (number)")
- selecting only the first of the 64 model answers
Then this single answer is scored.
```yaml
  - name: "score-first"
    filter:
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
      - function: "take_first"
```
Our second filter pipeline, "maj@64", does majority voting across all 64 answers via:
- applying the same regex to all responses, to get the numerical answer from the model for each of the 64 responses per problem
- applying majority voting to all responses, which then returns a length-1 `[<majority answer>]` list for each
- taking the first element of this length-1 list, to then score the sole response `<majority answer>` for each document.
```yaml
  - name: "maj@64"
    filter:
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
      - function: "majority_vote"
      - function: "take_first"
```
Our final filter pipeline, "maj@8", does majority voting across the first 8 of the model's responses per document via:
- subsetting the len-64 list of responses `[answer1, answer2, ..., answer64]` to `[answer1, answer2, ..., answer8]` for each document
- performing the same sequence of filters on these new sets of 8 responses, for each document.
```yaml
  - name: "maj@8"
    filter:
      - function: "take_first_k"
        k: 8
      - function: "regex"
        regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)"
      - function: "majority_vote"
      - function: "take_first"
```
Thus, given the 64 responses from our LM on each document, we can report metrics on these responses in these 3 different ways, as defined by our filter pipelines.
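To make the semantics concrete, the standalone sketch below mimics what the "maj@64" pipeline does for a single document's responses. It is illustrative pseudocode using only the standard library, not the harness's actual filter implementations.
```python
# Illustrative only: approximate semantics of regex -> majority_vote -> take_first
# for one document's list of sampled generations.
import re
from collections import Counter

ANSWER_RE = re.compile(r"The answer is (\-?[0-9\.\,]*[0-9]+)")


def majority_answer(doc_responses):
    matches = [ANSWER_RE.search(r) for r in doc_responses]
    answers = [m.group(1) for m in matches if m is not None]  # "regex" step
    if not answers:
        return None  # no parseable answer in any sample
    ranked = Counter(answers).most_common()  # "majority_vote" step
    return ranked[0][0]                      # "take_first" step: one answer to score


print(majority_answer(["The answer is 42", "The answer is 41", "The answer is 42"]))
# 42
```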
## Embedded Python Code
You can use Python functions for certain arguments by using the `!function` operator after the argument name, followed by `<filename>.<python_function_name>` (see the sketch after this list). This feature can be used for the following arguments:
1. `doc_to_text`
2. `doc_to_target`
3. `gold_alias`
4. `aggregation` for a `metric` in `metric_list`
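For instance, a hypothetical helper placed in a `utils.py` file next to the task's YAML could be referenced as `doc_to_text: !function utils.format_question`; the file and function names here are purely illustrative.
```python
# utils.py (hypothetical helper module living alongside the task YAML)
def format_question(doc):
    # Build the model's input string from one raw dataset row.
    return f"Question: {doc['question']}\nAnswer:"
```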
## (No Longer Recommended) Direct `Task` Subclassing
The prior method for implementing new tasks was to subclass `Task`. While we intend to migrate all tasks to the new YAML implementation option going forward, it remains possible to subclass the Task class and implement custom logic. For more information, see `docs/task_guide.md` in v0.3.0 of the `lm-evaluation-harness`.
## Including a Base YAML
You can base a YAML on another YAML file as a template. This can be handy when you need to change just the prompt for `doc_to_text` but keep the rest the same, or change `filters` to compare which works better. Simply use `include` in the YAML file and write the name of the template you want to base it on. This assumes that the base template is in the same directory; otherwise, you will need to provide the full path.
```
include: <YAML filename or with full path>
...
```
You can find an example of how to use this feature at [gsm8k-cot-self-consistency.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/3c07cc04a92fc467d7c9a94894aeddd58c93a5da/lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml), which is based on [gsm8k-cot.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/3c07cc04a92fc467d7c9a94894aeddd58c93a5da/lm_eval/tasks/gsm8k/gsm8k-cot.yaml).
## Passing Arguments to Metrics
Metrics can be defined in the `metric_list` argument when building the YAML config. Multiple metrics can be listed along with any auxiliary arguments. For example, when using the [`exact_match` metric](https://github.com/huggingface/evaluate/tree/main/metrics/exact_match), auxiliary arguments such as `ignore_case`, `ignore_punctuation`, and `regexes_to_ignore` can be listed as well; they will be passed to the metric function as `kwargs` (see the snippet after the example below). Some metrics have predefined values for `aggregation` and `higher_is_better`, so listing only the metric name can be sufficient.
```yaml
metric_list:
  - metric: acc
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: false
    regexes_to_ignore:
      - ","
      - "\\$"
```
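For reference, those extra keys correspond to keyword arguments of the metric's `compute` call. The snippet below shows the equivalent direct call to HF Evaluate's `exact_match` metric; this is a standalone sketch, not how the harness invokes the metric internally.
```python
# Standalone illustration of where the extra kwargs end up (HF Evaluate's exact_match).
import evaluate

exact_match = evaluate.load("exact_match")
result = exact_match.compute(
    predictions=["The answer is $1,000"],
    references=["the answer is 1000"],
    ignore_case=True,
    ignore_punctuation=False,
    regexes_to_ignore=[",", "\\$"],
)
print(result["exact_match"])  # 1.0 -- commas/"$" stripped, case ignored
```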
### Natively Supported Metrics
Here we list all metrics currently supported natively in `lm-eval`:
Metrics:
* `acc` (accuracy)
* `acc_norm` (length-normalized accuracy)
* `acc_mutual_info` (baseline loglikelihood - normalized accuracy)
* `perplexity`
* `word_perplexity` (perplexity per word)
* `byte_perplexity` (perplexity per byte)
* `bits_per_byte`
* `matthews_corrcoef` (Matthews correlation coefficient)
* `f1` (F1 score)
* `bleu`
* `chrf`
* `ter`
Aggregation functions:
* `mean`
* `median`
* `perplexity`
* `weighted_perplexity`
* `bits_per_byte`
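If a task needs a metric or aggregation that is not listed above, the refactor also exposes registration decorators (see `lm_eval/api/registry.py` in the code changes below). The sketch that follows mirrors their usage; the metric name, aggregation name, and function bodies are purely illustrative, and the decorator signatures may change as the refactor evolves.
```python
# Illustrative registration of a custom aggregation and metric (names are made up).
from lm_eval.api.registry import register_aggregation, register_metric


@register_aggregation("max")
def max_agg(arr):
    # Aggregate per-example scores by taking the maximum.
    return max(arr)


@register_metric(
    metric="my_passthrough_score",
    higher_is_better=True,
    output_type="greedy_until",
    aggregation="max",
)
def my_passthrough_score(items):  # passthrough, like the built-in accuracy metrics
    return items
```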
## Good Reference Tasks
Contributing a new task can be daunting! Luckily, much of the work has often been done for you in a different, similarly evaluated task. Good examples of task implementations to study include:
Multiple choice tasks:
- SciQ (`lm_eval/tasks/sciq/sciq.yaml`)
Corpus perplexity evaluations:
- Wikitext (`lm_eval/tasks/wikitext/wikitext.yaml`)
Generative tasks:
- GSM8k (`lm_eval/tasks/gsm8k/gsm8k.yaml`)
Tasks using complex filtering:
- GSM8k with CoT (+ with Self-Consistency): (`lm_eval/tasks/gsm8k/gsm8k-cot.yaml` ; `lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml`)
# New Task Guide
`lm-evaluation-harness` is a framework that strives to support a wide range of zero- and few-shot evaluation tasks on autoregressive language models (LMs).
This documentation page provides a walkthrough to get started creating your own task, on the `big-refactor` branch of the repository (which will become v0.5.0 in the future).
## Setup
If you haven't already, go ahead and fork the main repo, clone it, create a branch with the name of your task, and install the project requirements in your environment:
```sh
# After forking...
git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
cd lm-evaluation-harness
git checkout big-refactor
git checkout -b <task-name>
pip install -e ".[dev]"
```
As a concrete example, we'll walk through reimplementing the `gsm8k` benchmark (a *generative* task which requires sampling text from a model) and the `sciq` benchmark (a *discriminative*, or *multiple choice*, task where the model picks the most likely of several fixed answer choices).
## Creating a YAML file
Tasks in the eval harness are largely implemented via YAML files. In this section we'll note which existing tasks are worth "forking" or building off of, and step through the different arguments every task will need.
To implement a new standard task, we'll need to write a YAML file which configures our task logic. We start by making a new empty YAML file:
```sh
touch lm_eval/tasks/new_mcqa.yaml
```
or
```sh
touch lm_eval/tasks/new_generative_task.yaml
```
### Selecting and configuring a dataset
All data downloading and management is handled through the HuggingFace (**HF**) [`datasets`](https://github.com/huggingface/datasets) API. So, the first thing you should do is check to see if your task's dataset is already provided in their catalog [here](https://huggingface.co/datasets). If it's not in there, please consider adding it to their Hub to make it accessible to a wider user base by following their [new dataset guide](https://github.com/huggingface/datasets/blob/master/ADD_NEW_DATASET.md).
Once you have a HuggingFace dataset prepared for your task, we want to assign our new YAML to use this dataset:
```yaml
dataset_path: ... # the name of the dataset on the HF Hub.
dataset_name: ... # the dataset configuration to use. Leave `null` if your dataset does not require a config to be passed. See https://huggingface.co/docs/datasets/load_hub#configurations for more info.
dataset_kwargs: null # any extra keyword arguments that should be passed to the dataset constructor, e.g. `data_dir`.
```
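Under the hood, these three fields are (roughly) forwarded to `datasets.load_dataset`. The sketch below shows the equivalent direct call for an assumed dataset; the harness performs this download step for you.
```python
# Roughly what the harness does with the fields above (sketch).
import datasets

dataset = datasets.load_dataset(
    path="super_glue",   # dataset_path
    name="boolq",        # dataset_name (the config); omit if not needed
    # **dataset_kwargs would be unpacked here, e.g. data_dir="..." for local files
)
print(dataset)  # DatasetDict listing the available splits
```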
Next, we'd like to tell our task what the dataset's train, validation, and test splits are named, if they exist:
```yaml
training_split: <split name of training set, or `null`>
validation_split: <split name of val. set, or `null`>
test_split: <split name of test set, or `null`>
```
Tests will run on the `test_split` if it is available, and otherwise evaluate on the `validation_split`.
We can also specify from which split the task should retrieve few-shot examples via:
```yaml
fewshot_split: <split name to draw fewshot examples from, or `null`>
```
though if this is not set, we will default to train/validation/test sets, in that order.
### Writing a prompt with Jinja 2
The next thing we need to do is decide what format to use when presenting the data to the LM. This is our **prompt**, where we'll define both an input and output format.
We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format.
To write a prompt, users are required to write two YAML fields in Jinja as strings:
```yaml
doc_to_text:
doc_to_target:
```
Suppose our dataset has a `"question"` field and an `"answer"` field, both strings. Given a `document` object that is a row of our dataset, we want the model to see:
```
Question: {document[question]}
Answer:
```
We do this by writing
```yaml
doc_to_text: "Question: {{question}}\nAnswer:"
```
Here, `{{question}}` will be replaced by `doc["question"]` when the prompt template is rendered.
Our intended output is for the model to predict a single whitespace, and then the answer to the question. We do this via:
```yaml
doc_to_target: "{{answer}}"
```
**Important**: We always add a single whitespace between the input and output, such that the full input-output string is `doc_to_text(doc) + " " + doc_to_target(doc)`. `doc_to_text` should therefore not end with trailing whitespace, and `doc_to_target` should not begin with leading whitespace.
Users can also fill out the optional `template_aliases` YAML field, which is prepended to both the `doc_to_text` and `doc_to_target` fields. This field should not contain any text, only Jinja variable definitions (`{% ... %}` clauses). It can be used to perform more involved string manipulations and renamings of dataset columns while the main prompt fields remain easy to parse visually.
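If it helps to see the rendering concretely, the standalone snippet below uses the `jinja2` package directly on a toy document with `question` and `answer` fields. It illustrates the convention above and is not the harness's internal rendering code.
```python
# Standalone illustration of how the two Jinja fields render for one document.
from jinja2 import Template

doc = {
    "question": "What is human life expectancy in the United States?",
    "answer": "78 years",
}

rendered_input = Template("Question: {{question}}\nAnswer:").render(**doc)
rendered_target = Template("{{answer}}").render(**doc)

# The harness joins input and target with a single space (see the note above).
print(rendered_input + " " + rendered_target)
# Question: What is human life expectancy in the United States?
# Answer: 78 years
```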
### Using Python Functions for Prompts
There may be cases where the prompt we want to implement is easier to express in Python than in Jinja 2. For this, we can use Python helper functions that are referenced from the YAML config. Note that the function's script must be in the same directory as the YAML file.
A good example is WikiText, which requires a number of regex rules to clean the samples.
```python
import re


def wikitext_detokenizer(doc):
    string = doc["page"]
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    ...
    string = string.replace(" 's", "'s")
    return string
```
We can load this function in `doc_to_target` by using a `!function` operator after `doc_to_target` and followed by `<file name>.<function name>`. In the file [wikitext.yaml](https://github.com/EleutherAI/lm-evaluation-harness/blob/6ae376e3a43caa58b95bb8aa73054a94827bf560/lm_eval/tasks/wikitext/wikitext.yaml) we write:
```
doc_to_target: !function preprocess_wikitext.wikitext_detokenizer
```
### Importing a Prompt from Promptsource
[Promptsource](https://github.com/bigscience-workshop/promptsource/tree/main/promptsource) is a great repository for crowdsourced prompts for many datasets. We can load these prompts easily by using the `use_prompt` argument and filling it with the format `"promptsource:<name of prompt template>"`. To use this, `doc_to_text` and `doc_to_target` should be left undefined. This will fetch the template of the dataset defined in the YAML file.
For example, for SuperGLUE BoolQ, if we want to use the prompt template `GPT-3 Style`, we can add this to the YAML file:
```
use_prompt: "promptsource:GPT-3 Style"
```
#### Multiple choice format
For tasks which are multiple choice (a fixed, finite set of label words per each document) and evaluated via comparing loglikelihoods of all label words (the `multiple_choice` task output type) we enforce a particular convention on prompt format.
An annotated example in the case of SciQ is as follows:
```yaml
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # `template_aliases` must set the list of possible answer choices to the jinja variable `answer_choices` (List[str]), and set what the index within `answer_choices` of this doc's gold label (correct answer choice).
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
doc_to_target: "{{gold}}" # this must be castable to an integer. It must output only the index within `answer_choices` that is the correct label.
```
Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
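Conceptually, the harness then asks the model to score each `" {{choice}}"` continuation against the same input and checks whether the highest-loglikelihood choice matches `gold`. The sketch below spells that out; `model_loglikelihood` is a hypothetical stand-in for the model call, not a harness API.
```python
# Conceptual sketch of multiple_choice scoring (hypothetical helper, not harness code).
def score_multiple_choice(model_loglikelihood, context, answer_choices, gold):
    # One loglikelihood request per choice: score " " + choice given the context.
    lls = [model_loglikelihood(context, " " + choice) for choice in answer_choices]
    pred = max(range(len(lls)), key=lambda i: lls[i])  # argmax over choices
    return {"acc": float(pred == gold)}
```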
### Setting metrics
You're almost done! Now we need to choose how to score our task.
- *If this is a multiple choice task:* do you just want to check your model's accuracy in choosing the correct answer choice?
- *If this is a generation task:* do you just want to check how often your model outputs *exactly the ground-truth output string provided*?
If the answer to the above is no: you'll need to record what scoring metrics to use! Metrics can be listed in the following format:
```yaml
metric_list:
- metric: <name of the metric here>
aggregation: <name of the aggregation fn here>
higher_is_better: <true or false>
- metric: ...
aggregation: ...
higher_is_better: ...
```
`aggregation` and `higher_is_better` can optionally be left out to default to the manually-set defaults, if using a natively supported metric.
For a full list of natively supported metrics and aggregation functions see `docs/advanced_task_guide.md`. All metrics supported in [HuggingFace Evaluate](https://github.com/huggingface/evaluate/tree/main/metrics) can also be used, and will be loaded if a given metric name is not one natively supported in `lm-eval`.
### Optional, More Advanced Setup
Some tasks may require more advanced processing logic than is described in this guide.
As a heuristic check:
* Does your task require generating multiple free-form outputs per input document?
* Does your task require complex, multi-step post-processing of generated model outputs?
* Does your task require subsetting documents on the fly based on their content?
* Do you expect to compute metrics after applying multiple such processing steps on your model outputs?
* Does your task rely on metrics that need a custom implementation?
For more detail on the task system and advanced features, see `docs/advanced_task_guide.md` . If none of the above sound like they apply to your task, it's time to continue onto checking your task performance!
### Task name + groups (registering a task)
To test a task conveniently, it helps to *register* the task--that is, to give it a name and make the `lm-eval` library aware it exists!
If you're writing your YAML file inside the `lm_eval/tasks` folder, you just need to give your task a name! You can do this inside your YAML file:
```yaml
task: <name of the task>
```
Including a task name is mandatory.
It is often also convenient to label your task with several `groups`, or tags, though this field is optional:
```yaml
group:
- group1
- group2
```
This will add your task to the `group1` and `group2` groups, letting people know how to categorize your task and, if desired, run all tasks in one of these groups at once, your task included.
If your task is not in the `lm_eval/tasks` folder, you'll need to tell the Eval Harness where to look for YAML files.
You can do this via adding the Python snippet
```python
from lm_eval.tasks import include_task_folder
include_task_folder("/path/to/yaml/parent/folder")
```
to the top of any Python file that is run or imported when performing evaluation, such as `main.py`.
Passing `--tasks /path/to/yaml/file` is also accepted.
## Checking validity
After registering your task, you can now check on your data downloading and verify that the few-shot samples look as intended. Run the following command with your desired args:
```bash
python -m scripts.write_out \
    --output_base_path <path> \
    --tasks <your-task-name> \
    --sets <train | val | test> \
    --num_fewshot K \
    --num_examples N
```
Open the file specified at the `--output_base_path <path>` and ensure it passes
a simple eye test.
## Checking performance + equivalence
It's now time to check models' performance on your task! In the evaluation harness, we intend to support a wide range of evaluation tasks and setups, but prioritize the inclusion of already-proven benchmarks following the precise evaluation setups in the literature where possible.
To enable this, we provide a checklist that should be completed when contributing a new task, to enable accurate book-keeping and to ensure that tasks added to the library are well-tested and, where applicable, precedented.
### Task impl. checklist
The checklist is the following:
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Has the task been checked for equivalence with the original paper's methodology?
* [ ] Is the task in Eval-harness v0.3.0 or earlier?
* [ ] If so, has it been checked for regression from earlier versions? If there is a change in results, is it justified by matching the original authors' intended setup?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
## Submitting your task
You're all set! Now push your work and make a pull request to the `big-refactor` branch! Thanks for the contribution :). If there are any questions, please leave a message in the `#lm-thunderdome` channel on the EAI discord!
This folder is meant to contain instructions and task setups required to evaluate certain papers which may perform non-standard evaluation setups.
Tasks can be supported already in the library under `lm_eval/tasks`, or if highly paper-specific, may remain as YAMLs in the respective `examples/paper-title` folder.
## Verified Papers:
* [WIP] [Chain-of-Thought Prompting Elicits Reasoning in Large Language Models](https://arxiv.org/abs/2201.11903)
* Further details can be found in the `chain_of_thought` subfolder.
## Candidates to Support:
* Least-to-Most Prompting
* Algorithmic Prompting
* Other in-scope prompting techniques
* Multi-turn prompting strategies are likely out of scope for the repository.
* Pythia Suite: Term Frequencies over training
* All setups from GPT-3 Paper
* Varying few-shot orderings + selection ; Varying the label choices for multiple-choice tasks
* Your Paper Here!
# Chain-of-Thought Prompting Elicits Reasoning in Large Language Models
https://arxiv.org/abs/2201.11903
## All Tasks in Paper
* ...
* ...
* ...
## Reproduction Scripts
* ...
@@ -13,7 +13,7 @@ class Filter:
    """

    def __init__(self, *args, **kwargs):
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """
@@ -47,10 +47,7 @@ class FilterEnsemble:
        ]  # operate just on the model responses

        for f in self.filters:
            # apply filters in sequence
            resps = f.apply(resps)

        # add the end results after filtering to filtered_requests of their respective source instances.
        # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
...
@@ -6,96 +6,105 @@ import sacrebleu
import sklearn.metrics
import random

from lm_eval.api.registry import register_metric, register_aggregation


# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    return arr[len(arr) // 2]


@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)


@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items


def pop_stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
@@ -110,12 +119,7 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))


@register_metric(metric="matthews_corrcoef", higher_is_better=True, aggregation="mean")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
@@ -123,7 +127,12 @@ def matthews_corrcoef(items):
    return sklearn.metrics.matthews_corrcoef(golds, preds)


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
@@ -133,6 +142,12 @@ def f1_score(items):
    return np.max(fscore)


@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
@@ -179,30 +194,12 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(scores_for_ground_truths)


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


@register_metric(metric="bleu", higher_is_better=True, aggregation="mean")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
@@ -220,7 +217,7 @@ def bleu(items):
    return sacrebleu.corpus_bleu(preds, refs).score


@register_metric(metric="chrf", higher_is_better=True, aggregation="mean")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
@@ -235,7 +232,7 @@ def chrf(items):
    return sacrebleu.corpus_chrf(preds, refs).score


@register_metric(metric="ter", higher_is_better=True, aggregation="mean")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
...
@@ -4,32 +4,6 @@ from typing import Union

from lm_eval import utils


class LM(abc.ABC):
    def __init__(self):
...
import os

task_registry = {}
group_registry = {}
task2func_index = {}
func2task_index = {}


def register_task(name):
    def wrapper(func):
        task_registry[name] = func
        func2task_index[func.__name__] = name
        task2func_index[name] = func.__name__
        return func

    return wrapper


def register_group(name):
    def wrapper(func):
        func_name = func2task_index[func.__name__]
        if name in group_registry:
            group_registry[name].append(func_name)
        else:
            group_registry[name] = [func_name]
        return func

    return wrapper
import os
import evaluate

from lm_eval.api.model import LM


MODEL_REGISTRY = {}


def register_model(*names):
    # either pass a list or a single alias.
    # function receives them as a tuple of strings

    def decorate(cls):
        for name in names:
            assert issubclass(
                cls, LM
            ), f"Model '{name}' ({cls.__name__}) must extend LM class"

            assert (
                name not in MODEL_REGISTRY
            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."

            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name):
    return MODEL_REGISTRY[model_name]


TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = []
func2task_index = {}


def register_task(name):
    def decorate(fn):
        assert (
            name not in TASK_REGISTRY
        ), f"task named '{name}' conflicts with existing registered task!"

        TASK_REGISTRY[name] = fn
        func2task_index[fn.__name__] = name
        return fn

    return decorate


def register_group(name):
    def decorate(fn):
        func_name = func2task_index[fn.__name__]
        if name in GROUP_REGISTRY:
            GROUP_REGISTRY[name].append(func_name)
        else:
            GROUP_REGISTRY[name] = [func_name]
        return fn

    return decorate


AGGREGATION_REGISTRY = {}
DEFAULT_AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {}
OUTPUT_TYPE_REGISTRY = {}
HIGHER_IS_BETTER_REGISTRY = {}

DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": [
        "acc",
        "acc_norm",
    ],
    "greedy_until": ["exact_match"],
}


def register_metric(**args):
    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        assert "metric" in args
        name = args["metric"]

        for key, registry in [
            ("metric", METRIC_REGISTRY),
            ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
            # ("output_type", OUTPUT_TYPE_REGISTRY),
            ("aggregation", DEFAULT_AGGREGATION_REGISTRY),
        ]:
            if key in args:
                value = args[key]
                assert (
                    value not in registry
                ), f"{key} named '{value}' conflicts with existing registered {key}!"

                if key == "metric":
                    registry[name] = fn
                elif key == "aggregation":
                    registry[name] = AGGREGATION_REGISTRY[value]
                else:
                    registry[name] = value

        return fn

    return decorate


def get_metric(name):
    try:
        return METRIC_REGISTRY[name]
    except KeyError:
        # TODO: change this print to logging?
        print(
            f"Could not find registered metric '{name}' in lm-eval, \
            searching in HF Evaluate library..."
        )
        try:
            metric_object = evaluate.load(name)
            return metric_object.compute
        except Exception:
            raise Warning(
                "{} not found in the evaluate library!".format(name),
                "Please check https://huggingface.co/evaluate-metric",
            )


def register_aggregation(name):
    # TODO: should we enforce a specific interface to aggregation metrics?
    def decorate(fn):
        assert (
            name not in AGGREGATION_REGISTRY
        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"

        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


def get_aggregation(name):
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        raise Warning(
            "{} not a registered aggregation metric!".format(name),
        )
import abc import abc
from dataclasses import dataclass from dataclasses import dataclass, field, asdict
import re import re
import ast import ast
...@@ -18,43 +18,55 @@ from collections.abc import Callable ...@@ -18,43 +18,55 @@ from collections.abc import Callable
from lm_eval import utils from lm_eval import utils
from lm_eval.api import samplers from lm_eval.api import samplers
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.filter import FilterEnsemble
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api.metrics import ( from lm_eval.api.metrics import (
METRIC_REGISTRY, # get_metric,
AGGREGATION_REGISTRY, # get_aggregation,
HIGHER_IS_BETTER_REGISTRY,
get_metric,
get_aggregation,
mean, mean,
weighted_perplexity, weighted_perplexity,
bits_per_byte, bits_per_byte,
) )
from lm_eval.api.registry import (
METRIC_REGISTRY,
DEFAULT_METRIC_REGISTRY,
OUTPUT_TYPE_REGISTRY,
AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
DEFAULT_AGGREGATION_REGISTRY,
)
from lm_eval.logger import eval_logger ALL_OUTPUT_TYPES = [
from lm_eval.prompts import get_prompt "loglikelihood",
from lm_eval.filters import build_filter_ensemble "multiple_choice",
"loglikelihood_rolling",
"greedy_until",
]
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
task: str = None task: str = None
group: str = None group: Union[str, list] = None
names: str = None
reference: str = None reference: str = None
task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
)
base_task: str = None
dataset_path: str = None dataset_path: str = None
dataset_name: str = None dataset_name: str = None
dataset_kwargs: dict = None
training_split: str = None training_split: str = None
validation_split: str = None validation_split: str = None
test_split: str = None test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None template_aliases: str = None
aliases: Union[str, list] = None
doc_to_text: Union[Callable, str] = None doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None doc_to_target: Union[Callable, str] = None
use_prompt: str = None
num_fewshot: int = 0 num_fewshot: int = 0
batch_size: int = 1 batch_size: int = 1
...@@ -63,14 +75,11 @@ class TaskConfig(dict): ...@@ -63,14 +75,11 @@ class TaskConfig(dict):
metric_list: str = None metric_list: str = None
gold_alias: str = None gold_alias: str = None
output_type: str = "greedy_until" output_type: str = "greedy_until"
generation_kwargs: dict = None
delimiter: str = "\n\n" delimiter: str = "\n\n"
filter_list: Union[str, list] = None filter_list: Union[str, list] = None
normalization: str = (
None # TODO: add length-normalization of various types, mutual info
)
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: str = None doc_to_decontamination_query: str = None
use_prompt: str = None
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
...@@ -85,13 +94,20 @@ class TaskConfig(dict): ...@@ -85,13 +94,20 @@ class TaskConfig(dict):
if type(self.doc_to_target) == str: if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set if type(self.gold_alias) == str:
if self.names: self.gold_alias = self.template_aliases + self.doc_to_target
self.task_name = self.names[0]
if self.generation_kwargs or self.output_type == "greedy_until":
assert self.output_type == "greedy_until", "passed `generation_kwargs`, but not using a generation request type!"
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
def to_dict(self):
return asdict(self)
class Task(abc.ABC): class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems, """A task represents an entire benchmark including its dataset, problems,
...@@ -243,7 +259,7 @@ class Task(abc.ABC): ...@@ -243,7 +259,7 @@ class Task(abc.ABC):
else: else:
eval_logger.warning( eval_logger.warning(
"has_training_docs and has_validation_docs are False" "has_training_docs and has_validation_docs are False"
"using test_docs but this is not recommended." ", using test_docs but this is not recommended."
) )
return self.test_docs() return self.test_docs()
...@@ -308,7 +324,7 @@ class Task(abc.ABC): ...@@ -308,7 +324,7 @@ class Task(abc.ABC):
inst = self.construct_requests( inst = self.construct_requests(
doc=doc, doc=doc,
ctx=fewshot_ctx, ctx=fewshot_ctx,
metadata=(self._config["task_name"], doc_id, self._config.repeats), metadata=(self._config["task"], doc_id, self._config.repeats),
) )
if not isinstance(inst, list): if not isinstance(inst, list):
...@@ -436,13 +452,27 @@ class Task(abc.ABC): ...@@ -436,13 +452,27 @@ class Task(abc.ABC):
def apply_filters(self): def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters: for f in self._filters:
f.apply(self._instances) f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
def dump_config(self):
"""Returns a dictionary representing the task's config.
:returns: str
The fewshot context.
"""
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
# (batch size, num_fewshot)
return self._config.to_dict()
class ConfigurableTask(Task): class ConfigurableTask(Task):
VERSION = "2.0" VERSION = "Yaml"
OUTPUT_TYPE = None OUTPUT_TYPE = None
CONFIG = None CONFIG = None
...@@ -466,6 +496,7 @@ class ConfigurableTask(Task): ...@@ -466,6 +496,7 @@ class ConfigurableTask(Task):
) )
if self._config.output_type is not None: if self._config.output_type is not None:
assert self._config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self._config.output_type self.OUTPUT_TYPE = self._config.output_type
if self._config.dataset_path is not None: if self._config.dataset_path is not None:
...@@ -474,35 +505,42 @@ class ConfigurableTask(Task): ...@@ -474,35 +505,42 @@ class ConfigurableTask(Task):
if self._config.dataset_name is not None: if self._config.dataset_name is not None:
self.DATASET_NAME = self._config.dataset_name self.DATASET_NAME = self._config.dataset_name
if self._config.metric_list is not None: self._metric_fn_list = {}
self._metric_list = {} self._metric_fn_kwargs = {}
self._metric_kwargs = {}
self._aggregation_list = {} self._aggregation_list = {}
self._higher_is_better = {} self._higher_is_better = {}
for metric_config in self._config.metric_list:
_metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type]
if self._config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
for metric_config in self._config.metric_list:
assert "metric" in metric_config
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
aggregation = metric_config["aggregation"]
higher_is_better = metric_config["higher_is_better"]
kwargs = { kwargs = {
key: metric_config[key] key: metric_config[key]
for key in metric_config for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"] if key not in ["metric", "aggregation", "higher_is_better"]
} }
try:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation] self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
except:
if metric_name in METRIC_REGISTRY.keys(): eval_logger.warning(
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name] f"Metric {metric_name} not found, "
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[ "Searching from https://huggingface.co/evaluate-metric"
metric_name )
]
else:
self._higher_is_better[metric_name] = higher_is_better
try: try:
metric_object = evaluate.load(metric_name) metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object self._metric_fn_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs self._metric_fn_kwargs[metric_name] = kwargs
except Exception: except Exception:
raise Warning( raise Warning(
...@@ -510,12 +548,37 @@ class ConfigurableTask(Task): ...@@ -510,12 +548,37 @@ class ConfigurableTask(Task):
"Please check https://huggingface.co/evaluate-metric", "Please check https://huggingface.co/evaluate-metric",
) )
self.download(data_dir, cache_dir, download_mode) if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not"
f"using default aggregation for {metric_name}"
)
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not"
f"using default higher_is_better for {metric_name}"
)
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
self.download(self._config.dataset_kwargs)
self._training_docs = None self._training_docs = None
self._fewshot_docs = None self._fewshot_docs = None
self._filters = []
if self._config.filter_list is not None: if self._config.filter_list is not None:
self._filters = []
for filter_config in self._config.filter_list: for filter_config in self._config.filter_list:
for filter_pipeline in filter_config: for filter_pipeline in filter_config:
filter_name = filter_config["name"] filter_name = filter_config["name"]
...@@ -526,12 +589,11 @@ class ConfigurableTask(Task): ...@@ -526,12 +589,11 @@ class ConfigurableTask(Task):
key: function[key] for key in function if key != "function" key: function[key] for key in function if key != "function"
} }
components.append([function["function"], kwargs]) components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components) filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline) self._filters.append(filter_pipeline)
else: else:
self._filters = [ self._filters = [
build_filter_ensemble("take_first", [["take_first", None]]) build_filter_ensemble("none", [["take_first", None]])
] ]
if self._config.use_prompt is not None: if self._config.use_prompt is not None:
...@@ -545,7 +607,15 @@ class ConfigurableTask(Task): ...@@ -545,7 +607,15 @@ class ConfigurableTask(Task):
if self.fewshot_docs() is not None: if self.fewshot_docs() is not None:
self.sampler = samplers.Sampler( self.sampler = samplers.Sampler(
list(self.fewshot_docs()), self, rnd=random.Random() list(self.fewshot_docs()), self, rnd=random.Random()
) # TODO: pass the correct docs in here )
def download(self, dataset_kwargs=None):
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
)
def has_training_docs(self): def has_training_docs(self):
if self._config.training_split is not None: if self._config.training_split is not None:
...@@ -578,16 +648,16 @@ class ConfigurableTask(Task): ...@@ -578,16 +648,16 @@ class ConfigurableTask(Task):
return self.dataset[self._config.test_split] return self.dataset[self._config.test_split]
def fewshot_docs(self): def fewshot_docs(self):
if (self._config.num_fewshot > 0) and (self._config.fewshot_split is None): if self._config.fewshot_split is not None:
return self.dataset[self._config.fewshot_split]
else:
if self._config.num_fewshot > 0:
eval_logger.warning( eval_logger.warning(
"num_fewshot > 0 but fewshot_split is None. " "num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule." "using preconfigured rule."
) )
return super().fewshot_docs() return super().fewshot_docs()
elif self._config.fewshot_split is not None:
return self.dataset[self._config.fewshot_split]
def should_decontaminate(self): def should_decontaminate(self):
return self._config.should_decontaminate return self._config.should_decontaminate
...@@ -639,6 +709,23 @@ class ConfigurableTask(Task): ...@@ -639,6 +709,23 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
+    def gold_alias(self, doc):
+        # TODO: reevaluate if we need this. implemented to have a
+        # processed version of answer to put into gsm8k exact_match scoring as ref.
+        if self._config.gold_alias:
+            doc_to_target = self._config.gold_alias
+        else:
+            doc_to_target = self._config.doc_to_target
+
+        if type(doc_to_target) == str:
+            return utils.apply_template(doc_to_target, doc)
+        elif callable(doc_to_target):
+            return doc_to_target(doc)
+        elif hasattr(doc_to_target, "apply"):
+            return doc_to_target.apply(doc)[1]
+        else:
+            raise TypeError
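The dispatch above (template string vs. callable vs. an object exposing `.apply`) is the same pattern used for the other `doc_to_*` fields. A minimal standalone sketch, with `str.format_map` standing in for `utils.apply_template`:

```python
def resolve_target(spec, doc):
    """Resolve a target spec that may be a template string, a callable, or a
    promptsource-style object exposing .apply(doc)."""
    if isinstance(spec, str):
        return spec.format_map(doc)  # stand-in for utils.apply_template
    elif callable(spec):
        return spec(doc)
    elif hasattr(spec, "apply"):
        return spec.apply(doc)[1]
    raise TypeError(f"Unsupported target spec: {type(spec)}")


doc = {"question": "What is 2 + 2?", "answer": "4"}
print(resolve_target("{answer}", doc))             # -> 4
print(resolve_target(lambda d: d["answer"], doc))  # -> 4
```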
    def construct_requests(self, doc, ctx, **kwargs):
        if self.OUTPUT_TYPE == "loglikelihood":

...@@ -664,7 +751,7 @@ class ConfigurableTask(Task):

                for i, choice in enumerate(choices)
            ]
            # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in self._metric_list.keys():
+            if "acc_mutual_info" in self._metric_fn_list.keys():
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.

...@@ -686,7 +773,7 @@ class ConfigurableTask(Task):

            return request_list

        elif self.OUTPUT_TYPE == "greedy_until":
-            arguments = (ctx, self._config.delimiter)
+            arguments = (ctx, self._config.generation_kwargs)

            return Instance(
                request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
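This change means a `greedy_until` request now carries a whole generation-kwargs dict rather than a single stop-sequence delimiter. A hedged sketch of the resulting `arguments` tuple follows; the exact keys accepted depend on the task config and model backend, but `until`, `max_gen_toks`, and `do_sample` are the ones handled in the model code later in this diff.

```python
context = "Question: What is the capital of France?\nAnswer:"
generation_kwargs = {
    "until": ["\n"],      # stop sequences
    "max_gen_toks": 32,   # per-request generation budget
    "do_sample": False,   # greedy decoding unless a task opts into sampling
}
arguments = (context, generation_kwargs)  # replaces the old (context, delimiter) pair
```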
...@@ -694,25 +781,44 @@ class ConfigurableTask(Task):

    def process_results(self, doc, results):
+        # if callable(self._config.process_results):
+        #     return self._config.process_results(doc, results)

        result_dict = {}
+        use_metric = list(self._metric_fn_list.keys())
        if self.OUTPUT_TYPE == "loglikelihood":
            results = results[0]
            ll, is_greedy = results
-            result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
+            return {
+                **({"perplexity": ll} if "perplexity" in use_metric else {}),
+                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
+            }
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            (loglikelihood,) = results
-            words = self.count_words(self.doc_to_target(doc))
-            bytes_ = self.count_bytes(self.doc_to_target(doc))
+            _words = self.count_words(self.doc_to_target(doc))
+            _bytes = self.count_bytes(self.doc_to_target(doc))
            return {
-                "word_perplexity": (loglikelihood, words),
-                "byte_perplexity": (loglikelihood, bytes_),
-                "bits_per_byte": (loglikelihood, bytes_),
+                **(
+                    {"word_perplexity": (loglikelihood, _words)}
+                    if "word_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"byte_perplexity": (loglikelihood, _bytes)}
+                    if "byte_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"bits_per_byte": (loglikelihood, _bytes)}
+                    if "bits_per_byte" in use_metric
+                    else {}
+                ),
            }
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls = [
res[0] for res in results lls, is_greedy = zip(*results)
] # only retain loglikelihoods, discard is_greedy
gold = int(self.doc_to_target(doc)) gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval( choices = ast.literal_eval(
utils.apply_template( utils.apply_template(
...@@ -721,7 +827,7 @@ class ConfigurableTask(Task): ...@@ -721,7 +827,7 @@ class ConfigurableTask(Task):
) )
if ( if (
2 * len(choices) == len(lls) 2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_list.keys() and "acc_mutual_info" in self._metric_fn_list.keys()
): ):
# then we are doing mutual info. # then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods # this stores the "dryrun" / unconditional answer loglikelihoods
...@@ -735,21 +841,18 @@ class ConfigurableTask(Task): ...@@ -735,21 +841,18 @@ class ConfigurableTask(Task):
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0 acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
result_dict = { result_dict = {
"acc": acc, **({"acc": acc} if "acc" in use_metric else {}),
"acc_norm": acc_norm, **({"f1": (pred, gold)} if "f1" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
} }
            # TODO: set which normalization metrics should be reported, and calculate them

-            if "exact_match" in self._metric_list.keys():
+            if "exact_match" in self._metric_fn_list.keys():
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
-                is_greedy = [
-                    res[1] for res in results
-                ]  # take only the `is_greedy` results
                is_greedy = is_greedy[gold]  # take value for the gold answer
                result_dict["exact_match"] = int(is_greedy)

-            if "acc_mutual_info" in self._metric_list.keys():
+            if "acc_mutual_info" in use_metric:
                lls_mutual_info = [
                    ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
                ]
...@@ -759,20 +862,20 @@ class ConfigurableTask(Task):

        elif self.OUTPUT_TYPE == "greedy_until":
            if self._config.gold_alias is not None:
-                gold = doc[self._config.gold_alias]
+                gold = self.gold_alias(doc)
            else:
                gold = self.doc_to_target(doc)

-            for key, result in zip(self._metric_list.keys(), results):
-                _dict = self._metric_list[key].compute(
+            for key, result in zip(self._metric_fn_list.keys(), results):
+                _dict = self._metric_fn_list[key].compute(
                    references=[gold], predictions=[result], **self._metric_kwargs[key]
                )
-                result_dict[key] = _dict[key]
+                result_dict = {**result_dict, **_dict}

        else:
            raise ValueError(
                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until'",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
            )

        return result_dict

...
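The `use_metric` gating used throughout `process_results` above is just conditional dict-merging: only metrics the task actually requested end up in the per-document result dict. A standalone illustration with made-up values:

```python
use_metric = ["acc", "perplexity"]  # metrics requested by a task
ll, is_greedy = -1.23, True         # example loglikelihood result

result = {
    **({"perplexity": ll} if "perplexity" in use_metric else {}),
    **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
    **({"acc_norm": 0.0} if "acc_norm" in use_metric else {}),
}
print(result)  # {'perplexity': -1.23, 'acc': 1}
```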
...@@ -10,10 +10,10 @@ import torch

import numpy as np

import lm_eval.api
-import lm_eval.api.metrics
import lm_eval.tasks
import lm_eval.models
+import lm_eval.api.metrics
+import lm_eval.api.registry

from lm_eval.utils import (
    positional_deprecated,

...@@ -79,7 +79,7 @@ def simple_evaluate(

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
-        lm = lm_eval.api.model.get_model(model).create_from_arg_string(
+        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:

...@@ -148,15 +148,16 @@ def evaluate(
    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)
+    configs = collections.defaultdict(dict)
    requests = collections.defaultdict(list)
    # requests_origin = collections.defaultdict(list)
    # docs = {}

    # get lists of each type of request
    for task_name, task in task_dict.items():
        versions[task_name] = task.VERSION
+        configs[task_name] = dict(task.dump_config())  # TODO: don't access a private attribute here ; for non-YAML tasks handle this case

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        # task_docs = list(task_doc_func())

...@@ -292,9 +293,7 @@ def evaluate(

    # aggregate results ; run bootstrap CIs
    for (task_name, key, metric), items in vals.items():
        task = task_dict[task_name]
-        results[task_name][metric + " - filter=" + key] = task.aggregation()[
-            metric
-        ](items)
+        results[task_name][metric + "," + key] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this

...@@ -307,11 +306,9 @@ def evaluate(

        )
        if stderr is not None:
-            results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(
-                items
-            )
+            results[task_name][metric + "_stderr" + "," + key] = stderr(items)

-        return {"results": dict(results), "versions": dict(versions)}
+        return {"results": dict(results), "configs": dict(configs), "versions": dict(versions)}
    else:
        return None
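With this change, aggregated results are keyed as `"<metric>,<filter_name>"` (and `"<metric>_stderr,<filter_name>"`) instead of `"<metric> - filter=<filter_name>"`. The task name and numbers below are made up for illustration; `"none"` is the default filter key seen earlier in this diff.

```python
results = {
    "arc_easy": {
        "acc,none": 0.61,
        "acc_stderr,none": 0.01,
        "acc_norm,none": 0.55,
        "acc_norm_stderr,none": 0.01,
    }
}

metric, filter_key = "acc", "none"
print(results["arc_easy"][f"{metric},{filter_key}"])         # 0.61
print(results["arc_easy"][f"{metric}_stderr,{filter_key}"])  # 0.01
```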
...@@ -6,6 +6,8 @@ from . import extraction

FILTER_REGISTRY = {
    "take_first": selection.TakeFirstFilter,
    "regex": extraction.RegexFilter,
+    "majority_vote": selection.MajorityVoteFilter,
+    "take_first_k": selection.TakeKFilter,
    # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
    # that takes an input and returns a scalar and then should select the max reward,
    # or should implement different filters for different ways of handling a reward model's inference.

...
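The registry is a plain name-to-class mapping, so a helper such as `build_filter_ensemble` can instantiate a pipeline from the string names used in a task config. Below is a simplified, self-contained sketch of that pattern; the filter classes are stand-ins, not the real implementations.

```python
from collections import Counter


class TakeFirst:
    def apply(self, resps):
        return [r[0] for r in resps]


class MajorityVote:
    def apply(self, resps):
        return [[Counter(r).most_common(1)[0][0]] for r in resps]


REGISTRY = {"take_first": TakeFirst, "majority_vote": MajorityVote}

pipeline = [REGISTRY[name]() for name in ("majority_vote", "take_first")]
resps = [["4", "5", "4"]]  # three sampled answers for a single document
for f in pipeline:
    resps = f.apply(resps)
print(resps)  # ['4']
```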
...@@ -26,8 +26,6 @@ class RegexFilter(Filter):

                match = self.regex.search(resp)
                if match:
                    match = match.group(1).strip()
-                    match.replace(",", "")
-                    # TODO: should we assume any other filtering is performed?
                else:
                    match = self.fallback
                filtered.append(match)

...
+from collections import Counter

from lm_eval.api.filter import Filter


-class TakeFirstFilter:
+class TakeFirstFilter(Filter):
    def __init__(self):
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.

...@@ -12,3 +14,38 @@ class TakeFirstFilter:

        Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
        """
        return map(lambda r: r[0], resps)
+class TakeKFilter(Filter):
+    def __init__(self, *args, **kwargs):
+        self.k = kwargs.pop("k")
+
+        super().__init__(*args, **kwargs)
+
+    def apply(self, resps):
+        # check we have at least k responses per doc, else we can't take the first k
+        assert (
+            len(resps[0]) >= self.k
+        ), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+        return map(lambda r: r[: self.k], resps)
+
+
+class MajorityVoteFilter(Filter):
+    def __init__(self):
+        """
+        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+        """
+
+    def apply(self, resps):
+        """
+        Each entry of `resps` is a list of model responses.
+        We select the response that occurs most frequently in each entry of `resps`.
+        """
+
+        def select_majority(resp):
+            counts = Counter(resp)
+            vote = counts.most_common(1)[0][0]
+            return vote
+
+        return map(lambda r: [select_majority(r)], resps)
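A usage sketch for the two new filters, assuming they are importable from `lm_eval.filters.selection` and that the `Filter` base class needs no constructor arguments (both assumptions, based only on the code above). Each entry of `resps` is the list of sampled responses for one document.

```python
from lm_eval.filters.selection import MajorityVoteFilter, TakeKFilter  # assumed path

resps = [
    ["7", "7", "3"],    # doc 0: three samples
    ["12", "9", "12"],  # doc 1: three samples
]

take2 = TakeKFilter(k=2)
print(list(take2.apply(resps)))     # [['7', '7'], ['12', '9']]

majority = MajorityVoteFilter()
print(list(majority.apply(resps)))  # [['7'], ['12']]
```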
-from . import gpt2
-from . import gpt3
+from . import hf_causal
+from . import openai_completions
from . import textsynth
from . import dummy

...
import random

-from lm_eval.api.model import LM, register_model
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model


@register_model("dummy")

...
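The same `@register_model` decorator is how any new backend gets exposed to the harness. Below is a hedged sketch of a custom registration: the registry name is hypothetical, and the exact set of methods `LM` requires is not fully visible in this diff, so treat the method list as illustrative rather than authoritative.

```python
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


@register_model("my-constant-model")  # hypothetical registry name
class ConstantLM(LM):
    def loglikelihood(self, requests):
        # pretend every continuation is maximally likely
        return [(0.0, True) for _ in requests]

    def loglikelihood_rolling(self, requests):
        return [0.0 for _ in requests]

    def greedy_until(self, requests):
        # return an empty generation for every request
        return ["" for _ in requests]
```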
import torch
import transformers
+import copy
from tqdm import tqdm
import torch.nn.functional as F

from lm_eval import utils
from lm_eval.logger import eval_logger
-from lm_eval.api.model import LM, register_model
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model

from accelerate import Accelerator
from itertools import islice


-@register_model("hf-causal", "gpt2")
+@register_model("hf-causal")
class HFLM(LM):
    def __init__(
        self,

...@@ -37,10 +39,10 @@ class HFLM(LM):
if device not in ["cuda", "cpu"]: if device not in ["cuda", "cpu"]:
device = int(device) device = int(device)
self._device = torch.device(device) self._device = torch.device(device)
print(f"Using device '{device}'") eval_logger.info(f"Using device '{device}'")
else: else:
print("Device not specified") eval_logger.info("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}") eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
self._device = ( self._device = (
torch.device("cuda") torch.device("cuda")
if torch.cuda.is_available() if torch.cuda.is_available()
...@@ -55,10 +57,10 @@ class HFLM(LM): ...@@ -55,10 +57,10 @@ class HFLM(LM):
# TODO: update this to be less of a hack once subfolder is fixed in HF # TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained( self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
).to(self.device) ).to(self.device)
self.gpt2.eval() self.model.eval()
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, pretrained if tokenizer is None else tokenizer,
...@@ -74,22 +76,21 @@ class HFLM(LM): ...@@ -74,22 +76,21 @@ class HFLM(LM):
        if gpus > 1:
            accelerator = Accelerator()
            if gpus > accelerator.num_processes:
-                warning = (
+                eval_logger.warning(
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
-                print(warning)
                self._rank = accelerator.local_process_index
                self._world_size = accelerator.num_processes
            else:
-                self.gpt2 = accelerator.prepare(self.gpt2)
+                self.model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
                self.accelerator = accelerator

                if self.accelerator.is_local_main_process:
-                    print(f"Using {gpus} devices with data parallelism")
+                    eval_logger.info(f"Using {gpus} devices with data parallelism")

                self._rank = self.accelerator.local_process_index
                self._world_size = self.accelerator.num_processes

...@@ -103,17 +104,17 @@ class HFLM(LM):
    def max_length(self):
        try:
            if hasattr(self, "accelerator"):
-                return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
+                return self.accelerator.unwrap_model(self.model).config.n_ctx
            else:
-                return self.gpt2.config.n_ctx
+                return self.model.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            if hasattr(self, "accelerator"):
                return self.accelerator.unwrap_model(
-                    self.gpt2
+                    self.model
                ).config.max_position_embeddings
            else:
-                return self.gpt2.config.max_position_embeddings
+                return self.model.config.max_position_embeddings

    @property
    def max_gen_toks(self):
...@@ -150,15 +151,28 @@ class HFLM(LM):

        logits returned from the model
        """
        with torch.no_grad():
-            return self.gpt2(inps)[0]
+            return self.model(inps)[0]

-    def _model_generate(self, context, max_length, eos_token_id):
-        return self.gpt2.generate(
+    def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
+        # we require users to pass do_sample=True explicitly
+        # for non-greedy gen. This should be reevaluated when considering beam search.
+        if "do_sample" not in generation_kwargs.keys():
+            generation_kwargs["do_sample"] = False
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self.model).generate(
+                context,
+                max_length=max_length,
+                pad_token_id=eos_token_id,
+                eos_token_id=eos_token_id,
+                **generation_kwargs,
+            )
+        else:
+            return self.model.generate(
                context,
                max_length=max_length,
                pad_token_id=eos_token_id,
                eos_token_id=eos_token_id,
-                do_sample=False,
+                **generation_kwargs,
            )
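A minimal, hedged reproduction of what `_model_generate` does with plain `transformers`, useful for sanity-checking generation kwargs outside the harness; the `gpt2` checkpoint is just an example model.

```python
import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

context = tokenizer("The capital of France is", return_tensors="pt").input_ids
generation_kwargs = {"do_sample": False}  # greedy unless a task opts into sampling

with torch.no_grad():
    out = model.generate(
        context,
        max_length=context.shape[1] + 16,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        **generation_kwargs,
    )
print(tokenizer.decode(out[0][context.shape[1]:]))
```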
    def loglikelihood(self, requests):

...@@ -267,7 +281,7 @@ class HFLM(LM):

        # how this all works:
        #          CTX      CONT
        # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
-        # gpt2    \               \
+        # model   \               \
        # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
        # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

...@@ -347,18 +361,44 @@ class HFLM(LM):
        re_ord = utils.Reorderer([req.args for req in requests], _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
+        for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+            if isinstance(gen_kwargs, dict):
+                gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                if "until" in gen_kwargs.keys():
+                    until = gen_kwargs.pop("until")
                    if isinstance(until, str):
-                        until = [until]
+                        until = [gen_kwargs]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
+                )
+
+            if not until:
+                until = [self.tok_decode(self.eot_token_id)]
+
+            if "max_gen_toks" in gen_kwargs.keys():
+                max_gen_toks = gen_kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
+
+            try:
                (primary_until,) = self.tok_encode(until[0])
+            except Exception:
+                # if our primary until would be multiple tokens long, we'll have errors.
+                # TODO: handling this better will let us stop generating earlier + often.
+                primary_until = self.eot_token_id

            context_enc = torch.tensor(
-                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
+                [self.tok_encode(context)[max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
+                context=context_enc,
+                max_length=context_enc.shape[1] + max_gen_toks,
+                eos_token_id=primary_until,
+                **gen_kwargs,
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

...
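For clarity, here is a standalone sketch of the `gen_kwargs` handling that loop performs: stop sequences and the generation budget are popped out, and everything left is forwarded to `generate()`. (In the captured code the string case assigns `until = [gen_kwargs]`; the sketch follows the normalization the surrounding logic appears to intend, i.e. wrapping the string itself.) All values below are illustrative.

```python
import copy

eot_text = "<|endoftext|>"   # stand-in for self.tok_decode(self.eot_token_id)
default_max_gen_toks = 256   # stand-in for self.max_gen_toks

gen_kwargs = copy.deepcopy({"until": "\n\n", "max_gen_toks": 32, "temperature": 0.0})

until = gen_kwargs.pop("until", None)
if isinstance(until, str):
    until = [until]          # normalize a bare string to a list of stop strings
if not until:
    until = [eot_text]       # fall back to the end-of-text token

max_gen_toks = gen_kwargs.pop("max_gen_toks", default_max_gen_toks)

print(until, max_gen_toks, gen_kwargs)
# ['\n\n'] 32 {'temperature': 0.0}
```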