Commit ac50adb5 authored by lintangsutawika

merged with latest big-refactor

parents 6355d06f a3252ed7
name: Tasks Modified
on:
  push:
    branches:
      - 'big-refactor*'
  pull_request:
    branches:
      - 'big-refactor*'
  workflow_dispatch:
# comment/edit out the above to stop/change the triggers
jobs:
  changed_files:
    runs-on: ubuntu-latest # windows-latest || macos-latest
    timeout-minutes: 120
    name: Scan for changed tasks
    steps:
      - name: checkout
        uses: actions/checkout@v3
        with:
          fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
      # Uses the tj-actions/changed-files@v37 action to check for changes.
      # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
      # The `files_yaml` input optionally takes a yaml string to specify filters,
      # and prepends the filter name to the standard output names.
      - name: Check task folders
        id: changed-tasks
        uses: tj-actions/changed-files@v37.1.2
        with:
          # tasks checks the tasks folder and api checks the api folder for changes
          files_yaml: |
            tasks:
              - lm_eval/tasks/**
            api:
              - lm_eval/api/**
          write_output_files: true
      # The next step is optional; the files are written to the workspace by default (above).
      # so it's just for debugging
      - name: Run Tests
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        run: |
          echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
          echo "One or more test file(s) has changed."
          echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
      - name: Set up Python 3.9
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        uses: actions/setup-python@v4
        with:
          python-version: 3.9
          cache: 'pip'
          cache-dependency-path: setup.py
      - name: Install dependencies
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        run: |
          python -m pip install --upgrade pip
          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
          # Install optional git dependencies
          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Test with pytest
        # if new tasks are added, run tests on them
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
        run: python -m pytest tests/test_tasks.py -s -vv -n=auto
      # if api is modified, run tests on it
      - name: Test more tasks with pytest
        env:
          API: true
        if: steps.changed-tasks.outputs.api_any_modified == 'true'
        run: python -m pytest tests/test_tasks.py -s -vv -n=auto
@@ -40,39 +40,38 @@ jobs:
        flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    # mypy turned off for now
    # - name: Lint with mypy
    #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
  # Job 2
  testcpu:
    name: CPU Tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [ "3.9", "3.10", "3.11" ]
    timeout-minutes: 30
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: pip
          cache-dependency-path: setup.py
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
          # Install optional git dependencies
          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Test with pytest
        run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra
      - name: Archive artifacts
        uses: actions/upload-artifact@v3
        with:
          name: output_results
          path: |
            test_logs/*
@@ -4,6 +4,7 @@ Welcome to the docs for the LM Evaluation Harness!
## Table of Contents
* To learn about the public interface of the library, as well as how to evaluate via the commandline or as integrated into an external library, see the [Interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/user_guide.md)
* To learn how to add a new library, API, or model type to the library, as well as a quick explainer on the types of ways to evaluate an LM, see the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/model_guide.md).
* For a crash course on adding new tasks to the library, see our [New Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md).
* To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Advanced Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/advanced_task_guide.md).
...
# User Guide
This document details the interface exposed by `lm-eval` and the flags available to its users.
## Command-line Interface
A majority of users run the library by cloning it from GitHub and running the `main.py` script.
Equivalently, the library can be run via the `lm-eval` entrypoint at the command line.
This mode supports a number of command-line arguments, the details of which can also be seen by running with `-h` or `--help`:
* `--model` : Selects which model type or provider is evaluated. Must be a string corresponding to the name of the model type/provider being used. See [the main README](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor#commercial-apis) for a full list of enabled model names and supported libraries or APIs.
* `--model_args` : Controls parameters passed to the model constructor. Accepts a string containing comma-separated keyword arguments to the model class of the format `"arg1=val1,arg2=val2,..."`, for example `--model_args pretrained=EleutherAI/pythia-160m,dtype=float32`. For a full list of supported keyword arguments, see the initialization of the `lm_eval.api.model.LM` subclass, e.g. [`HFLM`](https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/models/huggingface.py#L66)
* `--tasks` : Determines which tasks or task groups are evaluated. Accepts a comma-separated list of task names or task group names. Must be solely comprised of valid tasks/groups.
* `--num_fewshot` : Sets the number of few-shot examples to place in context. Must be an integer.
* `--batch_size` : Sets the batch size used for evaluation. Can be a positive integer or `"auto"` to automatically select the largest batch size that will fit in memory, speeding up evaluation. One can pass `--batch_size auto:N` to re-select the maximum batch size `N` times during evaluation. This can help accelerate evaluation further, since `lm-eval` sorts documents in descending order of context length.
* `--max_batch_size` : Sets the maximum batch size to try to fit in memory, if `--batch_size auto` is passed.
* `--device` : Sets which device to place the model onto. Must be a string, for example `"cuda"`, `"cuda:0"`, `"cpu"`, or `"mps"`. Defaults to `"cuda"`; it can be ignored when running multi-GPU or evaluating a non-local model type.
* `--output_path` : A string of the form `dir/file.jsonl` or `dir/`. Provides a path where high-level results will be saved, either into the file named or into the directory named. If `--log_samples` is passed as well, then per-document outputs and metrics will be saved into the directory as well.
* `--log_samples` : If this flag is passed, then the model's outputs, and the text fed into the model, will be saved at per-document granularity. Must be used with `--output_path`.
* `--limit` : Accepts an integer, or a float between 0.0 and 1.0. If passed, limits evaluation to the first X documents per task (if an integer) or the first X% of documents per task (if a float). Useful for debugging, especially on costly API models.
* `--use_cache` : Should be a path where a sqlite db file can be written to. Takes a string of format `/path/to/sqlite_cache_` in order to create a cache db at `/path/to/sqlite_cache_rank{i}.db` for each process (0-NUM_GPUS). This allows results of prior runs to be cached, so that a given (model, task) pair does not need to be re-run in order to re-score or re-use its results.
* `--decontamination_ngrams_path` : Deprecated; see [this commit](https://github.com/EleutherAI/lm-evaluation-harness/commit/00209e10f6e27edf5d766145afaf894079b5fe10) or older for a working decontamination-checker tool.
* `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
* `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task.
* `--show_config` : If used, prints the full `lm_eval.api.task.TaskConfig` contents (the non-default settings from the task's YAML file) for each task which was run, at the completion of an evaluation. Useful when modifying a task's configuration YAML locally, and for recording the exact configurations used for debugging or reproducibility purposes.
* `--include_path` : Accepts a path to a folder. If passed, all YAML files containing `lm-eval`-compatible task configurations will be added to the task registry as available tasks. Useful when writing config files for your own tasks in a folder other than `lm_eval/tasks/`.
## External Library Usage
We also support using the library's external API within model training loops or other scripts.
`lm_eval` supplies two functions for external import and use: `lm_eval.evaluate()` and `lm_eval.simple_evaluate()`.
`simple_evaluate()` can be used by simply creating an `lm_eval.api.model.LM` subclass that implements the methods described in the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor/docs/model_guide.md), and wrapping your custom model in that class as follows:
```python
import lm_eval
...
my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.greedy_until()`
results = lm_eval.simple_evaluate( # call simple_evaluate
    model=lm_obj,
    tasks=["taskname1", "taskname2"],
    num_fewshot=0,
    ...
)
```
See https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/evaluator.py#L35 for a full description of all arguments available. All keyword arguments to simple_evaluate share the same role as the command-line flags described previously.
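As an illustrative sketch only (not taken verbatim from the library's docs), the command-line flags above map onto `simple_evaluate()` keyword arguments roughly as follows; the registered model name `"hf"` and the exact keyword names are assumptions that should be checked against the evaluator source linked above:
```python
import lm_eval

# Rough Python-API equivalent of a command-line run. Keyword names and the
# registered model name ("hf") are assumptions to verify against evaluator.py.
results = lm_eval.simple_evaluate(
    model="hf",                                      # same role as --model
    model_args="pretrained=EleutherAI/pythia-160m",  # same role as --model_args
    tasks=["taskname1", "taskname2"],                # same role as --tasks
    num_fewshot=0,                                   # same role as --num_fewshot
    batch_size="auto",                               # same role as --batch_size
    device="cuda:0",                                 # same role as --device
    limit=10,                                        # same role as --limit
)
```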
Additionally, the `evaluate()` function offers the core evaluation functionality provided by the library, but without some of the special handling and simplification + abstraction provided by `simple_evaluate()`.
See https://github.com/EleutherAI/lm-evaluation-harness/blob/365fcda9b85bbb6e0572d91976b8daf409164500/lm_eval/evaluator.py#L173 for more details.
As a brief example usage of `evaluate()`:
```python
import lm_eval
from my_tasks import MyTask1 # suppose you've defined a custom lm_eval.api.Task subclass in your own external codebase
...
my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.greedy_until()`
results = lm_eval.evaluate(
    lm=lm_obj,
    task_dict={"mytask1": MyTask1},
    ...
)
```
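For tasks that are already registered by name, a hedged sketch of building the `task_dict` with `lm_eval.tasks.get_task_dict()` (whose signature appears in `lm_eval/tasks/__init__.py` below) rather than constructing it by hand; the task name here is a placeholder:
```python
import lm_eval
import lm_eval.tasks

# Sketch only: build a task_dict from registered task or group names.
task_dict = lm_eval.tasks.get_task_dict(["taskname1"])

results = lm_eval.evaluate(
    lm=lm_obj,  # the LM subclass instance from the snippet above
    task_dict=task_dict,
)
```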
import abc
import os
import torch
from typing import Union, List, Tuple, Optional, Type, TypeVar
from sqlitedict import SqliteDict
import json
import hashlib
@@ -11,6 +12,8 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
@@ -111,11 +114,28 @@ class LM(abc.ABC):
pass
@classmethod
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
Parameters:
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
# TODO: delete once float16 MPS is fixed in torch stable
if (
args2.get("device") in ("mps", "mps:0")
or args.get("device") in ("mps", "mps:0")
and "dev" not in torch.__version__
):
args["dtype"] = "float32" args["dtype"] = "float32"
return cls(**args, **args2) return cls(**args, **args2)
......
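For context on the `create_from_arg_string()` change above, a minimal usage sketch; `SomeLM` is a hypothetical `LM` subclass and the argument values are illustrative only:
```python
# Hypothetical subclass; any concrete lm_eval.api.model.LM subclass would work here.
lm = SomeLM.create_from_arg_string(
    "pretrained=EleutherAI/pythia-160m,dtype=float32",     # "key1=value1,key2=value2" format
    additional_config={"batch_size": 8, "device": "cpu"},  # optional overrides; None values are dropped
)
```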
@@ -250,6 +250,11 @@ class Task(abc.ABC):
download_mode=download_mode,
)
@property
def config(self):
"""Returns the TaskConfig associated with this class."""
return self._config
@abc.abstractmethod
def has_training_docs(self):
"""Whether the task has a training set"""
@@ -352,7 +357,7 @@ class Task(abc.ABC):
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
eval_logger.info(
f"Building contexts for task '{self.config.task}' on rank {rank}..."
)
instances = []
@@ -362,14 +367,14 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
doc,
self.config.num_fewshot,
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats),
)
if not isinstance(inst, list):
@@ -457,9 +462,9 @@ class Task(abc.ABC):
if num_fewshot == 0:
# always prepend the (possibly empty) task description
labeled_examples = self.config.description
else:
labeled_examples = self.config.description + self.sampler.get_context(
doc, num_fewshot
)
@@ -469,7 +474,7 @@ class Task(abc.ABC):
elif type(example) == list:
return [labeled_examples + ex for ex in example]
elif type(example) == int:
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
else:
@@ -491,7 +496,7 @@ class Task(abc.ABC):
"""
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
# (num_fewshot)
return self.config.to_dict()
class ConfigurableTask(Task):
@@ -506,35 +511,35 @@ class ConfigurableTask(Task):
self._config = self.CONFIG
# Use new configurations if there was no preconfiguration
if self.config is None:
self._config = TaskConfig(**config)
# Overwrite configs
else:
if config is not None:
self._config.__dict__.update(config)
if self.config is None:
raise ValueError(
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
if self.config.output_type is not None:
assert self.config.output_type in ALL_OUTPUT_TYPES
self.OUTPUT_TYPE = self.config.output_type
if self.config.dataset_path is not None:
self.DATASET_PATH = self.config.dataset_path
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
@@ -543,7 +548,7 @@ class ConfigurableTask(Task):
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
assert "metric" in metric_config
metric_name = metric_config["metric"]
kwargs = {
@@ -552,7 +557,7 @@ class ConfigurableTask(Task):
if key not in ["metric", "aggregation", "higher_is_better"]
}
if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
@@ -594,13 +599,13 @@ class ConfigurableTask(Task):
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.download(self.config.dataset_kwargs)
self._training_docs = None
self._fewshot_docs = None
if self.config.filter_list is not None:
self._filters = []
for filter_config in self.config.filter_list:
for filter_pipeline in filter_config:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
@@ -615,10 +620,10 @@ class ConfigurableTask(Task):
else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self.config.use_prompt is not None:
eval_logger.info(f"loading prompt {self.config.use_prompt}")
self.prompt = get_prompt(
self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
)
else:
self.prompt = None
@@ -645,7 +650,7 @@ class ConfigurableTask(Task):
test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
if self.config.doc_to_choice is not None:
test_choice = self.doc_to_choice(test_doc)
if type(test_choice) is not list:
eval_logger.error("doc_to_choice must return list")
@@ -669,22 +674,22 @@ class ConfigurableTask(Task):
check_choices = test_choice
else:
check_choices = [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
delimiter_has_whitespace = (
True if self.config.target_delimiter[-1].isspace() else False
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.warning(
f'Both target_delimiter and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.warning(
f'Both target_delimiter and target choice: "{choice}" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(self, dataset_kwargs=None) -> None:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
@@ -693,52 +698,52 @@ class ConfigurableTask(Task):
)
def has_training_docs(self) -> bool:
if self.config.training_split is not None:
return True
else:
return False
def has_validation_docs(self) -> bool:
if self.config.validation_split is not None:
return True
else:
return False
def has_test_docs(self) -> bool:
if self.config.test_split is not None:
return True
else:
return False
def training_docs(self) -> datasets.Dataset:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.training_split]
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.validation_split]
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self.config.test_split]
def fewshot_docs(self):
if self.config.fewshot_split is not None:
return self.dataset[self.config.fewshot_split]
else:
if self.config.num_fewshot > 0:
eval_logger.warning(
f"Task '{self.config.task}': "
"num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule."
)
@@ -754,15 +759,15 @@ class ConfigurableTask(Task):
return self._instances
def should_decontaminate(self):
return self.config.should_decontaminate
def doc_to_decontamination_query(self, doc):
if self.config.should_decontaminate:
if self.config.doc_to_decontamination_query in self.features:
return doc[self.config.doc_to_decontamination_query]
else:
return ast.literal_eval(
utils.apply_template(self.config.doc_to_decontamination_query, doc)
)
def _process_doc(self, doc):
@@ -780,13 +785,13 @@ class ConfigurableTask(Task):
if self.prompt is not None:
doc_to_text = self.prompt
else:
doc_to_text = self.config.doc_to_text
if type(doc_to_text) == int:
return doc_to_text
elif type(doc_to_text) == str:
if doc_to_text in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]]
# else:
return doc[doc_to_text]
@@ -805,7 +810,7 @@ class ConfigurableTask(Task):
return applied_prompt[0]
else:
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
else:
print(type(doc_to_text))
raise TypeError
@@ -814,13 +819,13 @@ class ConfigurableTask(Task):
if self.prompt is not None:
doc_to_target = self.prompt
else:
doc_to_target = self.config.doc_to_target
if type(doc_to_target) == int:
return doc_to_target
elif type(doc_to_target) == str:
if doc_to_target in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]]
# else:
return doc[doc_to_target]
@@ -847,17 +852,17 @@ class ConfigurableTask(Task):
return applied_prompt[1]
else:
eval_logger.warning("Applied prompt returns empty string")
return self.config.fewshot_delimiter
else:
raise TypeError
def doc_to_choice(self, doc: Any) -> List[str]:
if self.prompt is not None:
doc_to_choice = self.prompt
elif self.config.doc_to_choice is None:
eval_logger.error("doc_to_choice was called but not set in config")
else:
doc_to_choice = self.config.doc_to_choice
if type(doc_to_choice) == str:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
@@ -878,8 +883,8 @@ class ConfigurableTask(Task):
# in multiple_choice tasks, this should be castable to an int corresponding to the index
# within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}.
if self.config.gold_alias is not None:
doc_to_target = self.config.gold_alias
else:
return self.doc_to_target(doc)
@@ -901,7 +906,7 @@ class ConfigurableTask(Task):
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc)
target_delimiter = self.config.target_delimiter
if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx
cont = self.doc_to_target(doc)
@@ -943,15 +948,16 @@ class ConfigurableTask(Task):
return request_list
elif self.OUTPUT_TYPE == "greedy_until":
arguments = (ctx, self.config.generation_kwargs)
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
)
def process_results(self, doc, results):
if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
@@ -1056,11 +1062,14 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until":
gold = self.doc_to_target(doc)
if self.config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
else:
gold = str(gold)
@@ -1071,6 +1080,10 @@ class ConfigurableTask(Task):
# return true if any are true
# TODO: this may break for multipLe_target, non zero-or-1 metrics
scores = []
if not isinstance(gold, list):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
...
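A brief sketch of how callers might use the new read-only `config` property introduced above; the task instance `my_task` is hypothetical, and the attribute names shown are those appearing in the diff:
```python
# my_task is an illustrative ConfigurableTask instance.
print(my_task.config.task)         # task name from the TaskConfig
print(my_task.config.num_fewshot)  # few-shot setting from the TaskConfig
print(my_task.dump_config())       # dict form of the config, as returned by to_dict()
```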
@@ -3,7 +3,7 @@ import string
import pickle
import traceback
from pprint import pprint
from typing import Iterator, Sequence, TypeVar, List, Tuple
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
@@ -21,7 +21,7 @@ T = TypeVar("T")
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
@@ -70,14 +70,14 @@ def word_ngrams(s: str, n: int) -> Iterator[str]:
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
@@ -157,7 +157,7 @@ class Janitor:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string: str) -> List[str]:
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
@@ -168,8 +168,8 @@ class Janitor:
return self.clean_python(dirty_string)
def _split_chunks(
self, dirty_string: str, dirty_parts: Sequence[Tuple]
) -> List[str]:
clean_chunks = []
splice_idx = 0
end = -1
@@ -197,7 +197,7 @@ class Janitor:
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string: str) -> List[str]:
contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
@@ -215,7 +215,7 @@ class Janitor:
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string: str) -> List[str]:
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
...
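A hedged sketch of the decontamination flow implied by the `Janitor` methods above; the constructor arguments are not shown in this diff, so defaults are assumed, and `register_contaminant_python()` / `clean_python()` are the pure-Python fallbacks visible here:
```python
# Import path omitted on purpose: Janitor is the class defined in the janitor module above.
jan = Janitor()  # assuming default constructor arguments exist
jan.register_contaminant_python("a benchmark passage that must not appear in training data")
chunks = jan.clean_python("a long training document possibly containing that passage")
print(chunks)  # list of clean chunks, or [] if the string was too dirty
```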
@@ -11,7 +11,6 @@ import numpy as np
import lm_eval.api
import lm_eval.tasks
import lm_eval.benchmarks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry
@@ -120,6 +119,8 @@ def simple_evaluate(
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
group, task_obj = task_obj
if task_obj is None:
continue
config = task_obj._config
if num_fewshot is not None:
@@ -184,7 +185,7 @@ def evaluate(
:param lm: obj
Language Model
:param task_dict: dict[str, Task]
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
:param limit: int, optional
Limit the number of examples per task (only use this for testing)
:param bootstrap_iters:
@@ -209,23 +210,30 @@ def evaluate(
samples = collections.defaultdict(list)
# tracks all Instances/requests a model must generate output on.
requests = collections.defaultdict(list)
# Aggregated task scores presented with groups
results_agg = collections.defaultdict(dict)
# Aggregated groups scores only
groups_agg = collections.defaultdict(dict)
# stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal
padding_requests = collections.defaultdict(int)
# store the hierarchy to do proper ordering
task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
# store the aggregation for aggregating across tasks in the same group
sample_agg_fn = collections.defaultdict(dict)
# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
group_name, task = task
task_hierarchy[group_name].append(task_name)
else:
task_hierarchy[task_name] = []
if task is None:
continue
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
@@ -301,6 +309,8 @@ def evaluate(
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
if task is None:
continue
task.apply_filters()
### Collect values of metrics on all datapoints ###
@@ -310,6 +320,8 @@ def evaluate(
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
if task is None:
continue
# TODO: make it possible to use a different metric per filter
# iterate over different filters used
for key in task.instances[0].filtered_resps.keys():
@@ -396,27 +408,64 @@ def evaluate(
vals = vals_torch
if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation
group_to_task = {}
for group in task_hierarchy.keys():
if group not in task_order:
task_order[group] = 0
if len(task_hierarchy[group]) > 0:
group_to_task[group] = task_hierarchy[group].copy()
for task in task_hierarchy[group]:
if task in task_order:
task_order[task] += 1
else:
task_order[task] = 1 + task_order[group]
if task in task_hierarchy:
group_to_task[group].remove(task)
group_to_task[group].extend(task_hierarchy[task])
task_to_group = {}
for group in group_to_task:
for task in group_to_task[group]:
if task in task_to_group:
task_to_group[task].append(group)
else:
task_to_group[task] = [group]
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
metric_key = metric + "," + key
if type(task) == tuple:
group_name, task = task
else:
group_name = None
agg_fn = task.aggregation()[metric]
task_score = agg_fn(items)
if group_name is not None:
sample_metric_key = metric + "(sample agg)," + key
for grouping in task_to_group[task_name]:
if metric_key in results[grouping]:
results[grouping][metric_key].append(task_score)
else:
results[grouping][metric_key] = [task_score]
if sample_metric_key in results[grouping]:
results[grouping][sample_metric_key] += items
else:
results[grouping][sample_metric_key] = items.copy()
sample_agg_fn[grouping][sample_metric_key] = agg_fn
results[task_name][metric_key] = task_score
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
@@ -431,19 +480,38 @@ def evaluate(
if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if bool(results):
for task_or_group in results.keys():
for metric in results[task_or_group].keys():
if type(results[task_or_group][metric]) == list:
if "(sample agg)" in metric:
results[task_or_group][metric] = sample_agg_fn[
task_or_group
][metric](results[task_or_group][metric])
else:
results[task_or_group][metric] = np.average(
results[task_or_group][metric]
)
versions[task_or_group] = "N/A"
for task_name, task in task_dict.items():
if type(task) == tuple:
group_name, task = task
order = task_order[group_name]
tabbed_name = "-" * order + group_name
results_agg[tabbed_name] = results[group_name]
versions[tabbed_name] = versions[group_name]
if order == 0:
groups_agg[group_name] = results[group_name]
order = task_order[task_name]
tabbed_name = "-" * order + task_name
results_agg[tabbed_name] = results[task_name]
versions[tabbed_name] = versions[task_name]
results_dict = {
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
}
...
@@ -107,17 +107,20 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+ ["mps", "mps:0"]
)
if device:
if device not in device_list:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
if device in ("mps", "mps:0") and "dev" not in torch.__version__:
eval_logger.info(
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cpu"
)
else:
eval_logger.info("Device not specified")
...
import ast
from typing import Dict
from lm_eval import utils
from lm_eval.logger import eval_logger
@@ -5,7 +8,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
"qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {{question}}\nA:",
@@ -88,6 +91,14 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
prompt_name, prompt_yaml_file["prompts"].keys()
)
category_name, *prompt_name = use_prompt.split(":")
# TODO allow to multiple prompt naming
# if len(prompt_name) > 1:
# prompt_list = []
# for prompt in prompt_name:
# prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
# else:
prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
return [":".join([category_name, prompt]) for prompt in prompt_list] return [":".join([category_name, prompt]) for prompt in prompt_list]
......
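For reference, the registry above is keyed by prompt category and then prompt name; a minimal lookup sketch, assuming `PROMPT_REGISTRY` is importable from `lm_eval.prompts`:
```python
from lm_eval.prompts import PROMPT_REGISTRY  # assumed import path for the module shown above

template = PROMPT_REGISTRY["qa-basic"]["question-newline-answer"]
print(template)  # "Question: {{question}}\nAnswer:"
```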
@@ -16,7 +16,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] MCTACO
- [x] Pubmed QA
- [x] SciQ
- [x] QASPER
- [x] QA4MRE
- [x] TriviaQA
- [x] AI2 ARC
@@ -36,7 +36,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] TruthfulQA (mc1)
- [x] TruthfulQA (mc2)
- [x] TruthfulQA (gen)
- [x] MuTual
- [ ] Hendrycks Math (Hailey)
- [x] Asdiv
- [ ] GSM8k
...
import os
import yaml
from typing import List, Union, Dict
from lm_eval import utils
from lm_eval import prompts
@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
)
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
@@ -38,7 +38,35 @@ def register_configurable_task(config: dict[str, str]) -> int:
return 0
def register_configurable_group(config: Dict[str, str]) -> int:
group = config["group"]
all_task_list = config["task"]
config_list = [task for task in all_task_list if type(task) != str]
task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list:
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
)
for config in var_configs:
register_configurable_task(config)
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
return 0
def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
all_configs = []
if "use_prompt" in config:
prompt_list = prompts.load_prompt_list(
@@ -69,14 +97,14 @@ def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
return all_configs
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir: str, register_task=True) -> None:
"""
Calling this function
"""
@@ -87,9 +115,16 @@ def include_task_folder(task_dir: str) -> None:
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
if register_task:
all_configs = check_prompt_config(config)
for config in all_configs:
register_configurable_task(config)
else:
# If a `task` in config is a list,
# that means it's a benchmark
if type(config["task"]) == list:
register_configurable_group(config)
except Exception as error:
eval_logger.warning(
@@ -102,6 +137,8 @@ def include_task_folder(task_dir: str) -> None:
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
# Register Benchmarks after all tasks have been added
include_task_folder(task_dir, register_task=False)
def get_task(task_name, config):
@@ -128,7 +165,7 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
config = {**kwargs}
@@ -136,6 +173,9 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
task_name_from_config_dict = {}
task_name_from_object_dict = {}
if type(task_name_list) != list:
task_name_list = [task_name_list]
for task_element in task_name_list:
if isinstance(task_element, str):
@@ -143,12 +183,20 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
if task_name not in task_name_from_registry_dict:
task_obj = get_task_dict(task_name)
if task_name in task_obj.keys():
task_dict = {
task_name: (group_name, task_obj[task_name]),
}
else:
task_dict = {
task_name: (group_name, None),
**task_obj,
}
task_name_from_registry_dict = {
**task_name_from_registry_dict,
**task_dict,
}
else:
task_name = task_element
...