Commit 50e99bd7 authored by Herbie Bradley

Merge remote-tracking branch 'origin/big-refactor' into calibration

parents 3d4c4cd6 a3252ed7
name: Tasks Modified

on:
  push:
    branches:
      - 'big-refactor*'
  pull_request:
    branches:
      - 'big-refactor*'
  workflow_dispatch:
# comment/edit out the above to stop/change the triggers

jobs:
  changed_files:
    runs-on: ubuntu-latest  # windows-latest || macos-latest
    timeout-minutes: 120
    name: Scan for changed tasks
    steps:
      - name: checkout
        uses: actions/checkout@v3
        with:
          fetch-depth: 2  # OR "2" -> To retrieve the preceding commit.

      # Uses the tj-actions/changed-files@v37 action to check for changes.
      # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
      # The `files_yaml` input optionally takes a yaml string to specify filters,
      # and prepends the filter name to the standard output names.
      - name: Check task folders
        id: changed-tasks
        uses: tj-actions/changed-files@v37.1.2
        with:
          # tasks checks the tasks folder and api checks the api folder for changes
          files_yaml: |
            tasks:
              - lm_eval/tasks/**
            api:
              - lm_eval/api/**
          write_output_files: true

      # The next step is optional; the files are written to the workspace by default (above).
      # so it's just for debugging
      - name: Run Tests
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        run: |
          echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
          echo "One or more test file(s) has changed."
          echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"

      - name: Set up Python 3.9
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        uses: actions/setup-python@v4
        with:
          python-version: 3.9
          cache: 'pip'
          cache-dependency-path: setup.py

      - name: Install dependencies
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
        run: |
          python -m pip install --upgrade pip
          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
          # Install optional git dependencies
          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

      - name: Test with pytest
        # if new tasks are added, run tests on them
        if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
        run: python -m pytest tests/test_tasks.py -s -vv -n=auto

      # if api is modified, run tests on it
      - name: Test more tasks with pytest
        env:
          API: true
        if: steps.changed-tasks.outputs.api_any_modified == 'true'
        run: python -m pytest tests/test_tasks.py -s -vv -n=auto
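For reference, a minimal sketch (not part of the workflow) of how the file written by write_output_files: true could be consumed, assuming the default space/newline-separated format produced by tj-actions/changed-files and the path echoed in the "Run Tests" step:

# Hypothetical helper: reads the changed-file list that the "Check task folders"
# step writes under .github/outputs/ when write_output_files is enabled.
from pathlib import Path

CHANGED_TASKS_FILE = Path(".github/outputs/tasks_all_changed_and_modified_files.txt")

def changed_task_files() -> list[str]:
    """Return changed files under lm_eval/tasks/**, or [] if nothing was written."""
    if not CHANGED_TASKS_FILE.exists():
        return []
    return CHANGED_TASKS_FILE.read_text().split()

if __name__ == "__main__":
    for path in changed_task_files():
        print("changed task file:", path)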
@@ -40,39 +40,38 @@ jobs:
          flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      # mypy turned off for now
      # - name: Lint with mypy
      #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
  # Job 2
  testcpu:
    name: CPU Tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [ "3.9", "3.10", "3.11" ]
    timeout-minutes: 30
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: pip
          cache-dependency-path: setup.py
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
          # Install optional git dependencies
          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Test with pytest
        run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra
      - name: Archive artifacts
        uses: actions/upload-artifact@v3
        with:
          name: output_results
          path: |
            test_logs/*
import abc
import os
import torch
from typing import Union, List, Tuple, Optional, Type, TypeVar
from sqlitedict import SqliteDict
import json
import hashlib

@@ -11,6 +12,8 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger

T = TypeVar("T", bound="LM")


class LM(abc.ABC):
    def __init__(self) -> None:

@@ -111,11 +114,28 @@ class LM(abc.ABC):
        pass

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        # TODO: delete once float16 MPS is fixed in torch stable
        if (
            args2.get("device") in ("mps", "mps:0")
            or args.get("device") in ("mps", "mps:0")
        ) and "dev" not in torch.__version__:
            args["dtype"] = "float32"
        return cls(**args, **args2)
...
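A usage sketch of the new signature, based only on the docstring's key1=value1,key2=value2 format; the HFLM class name, its import path, and the pretrained/dtype argument names are assumptions for illustration:

# Illustrative only: constructing an LM subclass from an argument string.
from lm_eval.models.huggingface import HFLM  # assumed import path

lm = HFLM.create_from_arg_string(
    "pretrained=EleutherAI/pythia-160m,dtype=float32",        # key1=value1,key2=value2
    additional_config={"device": "cpu", "batch_size": None},  # None values are dropped
)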
@@ -674,11 +674,11 @@ class ConfigurableTask(Task):
            check_choices = test_choice
        else:
            check_choices = [test_target]
        if self.config.doc_to_choice is not None:
            for choice in check_choices:
                choice_has_whitespace = True if choice[0].isspace() else False
                delimiter_has_whitespace = (
                    True if self.config.target_delimiter[-1].isspace() else False
                )
                if delimiter_has_whitespace and choice_has_whitespace:

@@ -1080,6 +1080,9 @@ class ConfigurableTask(Task):
            # it assumes that doc_to_target returns a number.
            choices = self.doc_to_choice(doc)
            gold = choices[gold]
        # we expect multiple_targets to be a list.
        elif self.multiple_target:
            gold = list(gold)
        else:
            gold = str(gold)

@@ -1090,6 +1093,10 @@ class ConfigurableTask(Task):
            # return true if any are true
            # TODO: this may break for multiple_target, non zero-or-1 metrics
            scores = []
            if not isinstance(gold, list):
                # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                # print(gold)
                gold = [gold]
            for gold_option in gold:
                try:
                    result_score = self._metric_fn_list[metric](
...
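A standalone sketch of the scoring pattern the last hunk enables: wrap a lone string target into a list, score every option, and treat the doc as correct if any option matches (per the "return true if any are true" comment). The helper name and exact-match metric are hypothetical stand-ins:

from typing import List, Union

def score_multi_target(prediction: str, gold: Union[str, List[str]]) -> float:
    """Exact-match stand-in for the per-metric scoring loop above."""
    # Normalize: a multiple_target dataset may still carry a single string answer.
    if not isinstance(gold, list):
        gold = [gold]
    # The doc scores 1.0 if any gold option matches the prediction.
    scores = [1.0 if prediction == option else 0.0 for option in gold]
    return max(scores) if scores else 0.0

assert score_multi_target("Paris", ["London", "Paris"]) == 1.0
assert score_multi_target("Paris", "Paris") == 1.0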
import os
import yaml

from lm_eval import utils
from lm_eval.tasks import register_configurable_task, check_prompt_config
from lm_eval.logger import eval_logger
from lm_eval.api.registry import (
    TASK_REGISTRY,
    GROUP_REGISTRY,
    ALL_TASKS,
)


def include_benchmarks(task_dir: str) -> None:
    for root, subdirs, file_list in os.walk(task_dir):
        if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
            for f in file_list:
                if f.endswith(".yaml"):
                    try:
                        benchmark_path = os.path.join(root, f)
                        with open(benchmark_path, "rb") as file:
                            yaml_config = yaml.full_load(file)

                        assert "group" in yaml_config
                        group = yaml_config["group"]
                        all_task_list = yaml_config["task"]
                        config_list = [
                            task for task in all_task_list if type(task) != str
                        ]
                        task_list = [
                            task for task in all_task_list if type(task) == str
                        ]

                        for task_config in config_list:
                            var_configs = check_prompt_config(
                                {
                                    **task_config,
                                    **{"group": group},
                                }
                            )
                            for config in var_configs:
                                register_configurable_task(config)

                        task_names = utils.pattern_match(task_list, ALL_TASKS)
                        for task in task_names:
                            if task in TASK_REGISTRY:
                                if group in GROUP_REGISTRY:
                                    GROUP_REGISTRY[group].append(task)
                                else:
                                    GROUP_REGISTRY[group] = [task]
                                ALL_TASKS.add(group)
                    except Exception as error:
                        eval_logger.warning(
                            "Failed to load benchmark in\n"
                            f" {benchmark_path}\n"
                            " Benchmark will not be added to registry\n"
                            f" Error: {error}"
                        )


task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_benchmarks(task_dir)
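For orientation, the structure include_benchmarks expects, written out as the dict yaml.full_load would return, plus a stand-in for utils.pattern_match under the assumption that it performs fnmatch-style wildcard expansion against the known task names; the task names are illustrative:

import fnmatch

# A benchmark YAML mixes plain task names (possibly wildcards) with inline config dicts.
benchmark_config = {
    "group": "pythia",
    "task": [
        "lambada_openai",                         # string -> resolved against ALL_TASKS
        "hendrycksTest*",                         # wildcard -> expanded by pattern matching
        {"task": "wikitext", "num_fewshot": 0},   # dict -> routed through check_prompt_config
    ],
}

def pattern_match_sketch(patterns, source):
    """Assumed behaviour of utils.pattern_match: fnmatch each pattern against known names."""
    names = set()
    for pattern in patterns:
        names.update(fnmatch.filter(source, pattern))
    return sorted(names)

ALL_TASKS = ["lambada_openai", "hendrycksTest-anatomy", "hendrycksTest-astronomy", "wikitext"]
string_tasks = [t for t in benchmark_config["task"] if isinstance(t, str)]
print(pattern_match_sketch(string_tasks, ALL_TASKS))
# -> ['hendrycksTest-anatomy', 'hendrycksTest-astronomy', 'lambada_openai']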
@@ -3,7 +3,7 @@ import string
import pickle
import traceback
from pprint import pprint
from typing import Iterator, Sequence, TypeVar, List, Tuple

# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup

@@ -21,7 +21,7 @@ T = TypeVar("T")
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator

@@ -70,14 +70,14 @@ def word_ngrams(s: str, n: int) -> Iterator[str]:
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))


def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    tokens_with_indices = split_indices(s)

@@ -157,7 +157,7 @@ class Janitor:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string: str) -> List[str]:
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""

@@ -168,8 +168,8 @@ class Janitor:
            return self.clean_python(dirty_string)

    def _split_chunks(
        self, dirty_string: str, dirty_parts: Sequence[Tuple]
    ) -> List[str]:
        clean_chunks = []
        splice_idx = 0
        end = -1

@@ -197,7 +197,7 @@ class Janitor:
            janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
        )

    def clean_cpp(self, dirty_string: str) -> List[str]:
        contamination_indices = janitor_util.clean_ngram_with_indices(
            dirty_string, self.delete_chars, self.ngram_n
        )

@@ -215,7 +215,7 @@ class Janitor:
            word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
        )

    def clean_python(self, dirty_string: str) -> List[str]:
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
...
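A quick usage check of two functions whose annotations changed above; the module path lm_eval.decontamination.janitor is assumed from the Janitor class, and the expected outputs follow the nltk-style ngram semantics and the split_indices one-liner shown in the hunk:

from lm_eval.decontamination.janitor import form_ngrams, split_indices

print(list(form_ngrams(iter(["a", "b", "c", "d"]), 2)))
# -> [('a', 'b'), ('b', 'c'), ('c', 'd')]

print(list(split_indices("the cat")))
# -> [('the', (0, 2)), ('cat', (4, 6))]  (end index is inclusive: m.end() - 1)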
@@ -118,6 +118,8 @@ def simple_evaluate(
        task_obj = task_dict[task_name]
        if type(task_obj) == tuple:
            group, task_obj = task_obj
            if task_obj is None:
                continue

        config = task_obj._config
        if num_fewshot is not None:

@@ -207,23 +209,30 @@ def evaluate(
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated groups scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store the ordering of tasks and groups
    task_order = collections.defaultdict(int)
    # store the aggregation for aggregating across tasks in the same group
    sample_agg_fn = collections.defaultdict(dict)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
        else:
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

@@ -299,6 +308,8 @@ def evaluate(
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
        if task is None:
            continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###

@@ -308,6 +319,8 @@ def evaluate(
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
        if task is None:
            continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():

@@ -468,27 +481,62 @@ def evaluate(
        vals = vals_torch

    if lm.rank == 0:
        ### Get task ordering for correct sample-wide aggregation
        group_to_task = {}
        for group in task_hierarchy.keys():
            if group not in task_order:
                task_order[group] = 0

            if len(task_hierarchy[group]) > 0:
                group_to_task[group] = task_hierarchy[group].copy()

            for task in task_hierarchy[group]:
                if task in task_order:
                    task_order[task] += 1
                else:
                    task_order[task] = 1 + task_order[group]

                if task in task_hierarchy:
                    group_to_task[group].remove(task)
                    group_to_task[group].extend(task_hierarchy[task])

        task_to_group = {}
        for group in group_to_task:
            for task in group_to_task[group]:
                if task in task_to_group:
                    task_to_group[task].append(group)
                else:
                    task_to_group[task] = [group]

        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if type(task) == tuple:
                group_name, task = task
            else:
                group_name = None

            agg_fn = task.aggregation()[metric]
            task_score = agg_fn(items)

            if group_name is not None:
                sample_metric_key = metric + "(sample agg)," + key
                for grouping in task_to_group[task_name]:
                    if metric_key in results[grouping]:
                        results[grouping][metric_key].append(task_score)
                    else:
                        results[grouping][metric_key] = [task_score]

                    if sample_metric_key in results[grouping]:
                        results[grouping][sample_metric_key] += items
                    else:
                        results[grouping][sample_metric_key] = items.copy()
                    sample_agg_fn[grouping][sample_metric_key] = agg_fn

            results[task_name][metric_key] = task_score

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them less iterations. still looking for a cleaner way to do this

@@ -503,19 +551,38 @@ def evaluate(
            if stderr is not None:
                results[task_name][metric + "_stderr" + "," + key] = stderr(items)

        if bool(results):
            for task_or_group in results.keys():
                for metric in results[task_or_group].keys():
                    if type(results[task_or_group][metric]) == list:
                        if "(sample agg)" in metric:
                            results[task_or_group][metric] = sample_agg_fn[
                                task_or_group
                            ][metric](results[task_or_group][metric])
                        else:
                            results[task_or_group][metric] = np.average(
                                results[task_or_group][metric]
                            )
                        versions[task_or_group] = "N/A"

        for task_name, task in task_dict.items():
            if type(task) == tuple:
                group_name, task = task
                order = task_order[group_name]
                tabbed_name = "-" * order + group_name
                results_agg[tabbed_name] = results[group_name]
                versions[tabbed_name] = versions[group_name]

                if order == 0:
                    groups_agg[group_name] = results[group_name]

            order = task_order[task_name]
            tabbed_name = "-" * order + task_name
            results_agg[tabbed_name] = results[task_name]
            versions[tabbed_name] = versions[task_name]

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
        }
...
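An illustration (hand-written, not harness output) of how the new bookkeeping lays out the final tables: member tasks are indented under their group with "-" prefixes in results_agg, while groups_agg keeps only the group-level rows. The metric values and task names are placeholders:

task_order = {"pythia": 0, "lambada_openai": 1, "wikitext": 1}
results = {
    "pythia": {"acc,none": 0.60},            # group-level aggregate
    "lambada_openai": {"acc,none": 0.62},
    "wikitext": {"acc,none": 0.58},
}

results_agg = {"-" * task_order[name] + name: metrics for name, metrics in results.items()}
groups_agg = {"pythia": results["pythia"]}

print(list(results_agg))
# -> ['pythia', '-lambada_openai', '-wikitext']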
@@ -101,17 +101,20 @@ class HFLM(LM):
        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
                + ["mps", "mps:0"]
            )
            if device:
                if device not in device_list:
                    device = int(device)
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
                if device in ("mps", "mps:0") and "dev" not in torch.__version__:
                    eval_logger.info(
                        "MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
                        "PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
                        "https://download.pytorch.org/whl/nightly/cpu"
                    )
            else:
                eval_logger.info("Device not specified")
...
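Related usage sketch: choosing a device string that fits the expanded device_list above. torch.backends.mps.is_available() is the standard availability check; the float32 fallback mirrors the logged caveat for non-nightly builds:

import torch

def pick_device() -> str:
    # Prefer CUDA, then Apple-silicon MPS, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = pick_device()
# Non-nightly ("dev"-less) PyTorch builds only support float32 on MPS.
dtype = "float32" if device in ("mps", "mps:0") and "dev" not in torch.__version__ else "auto"
print(device, dtype)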
import ast
from typing import Dict

from lm_eval import utils
from lm_eval.logger import eval_logger

@@ -5,7 +8,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",

@@ -63,6 +66,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
    else:
        prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)

    category_name, *prompt_name = use_prompt.split(":")
    # TODO: allow multiple prompt names
    # if len(prompt_name) > 1:
    #     prompt_list = []
    #     for prompt in prompt_name:
    #         prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
    # else:
    prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
    return [":".join([category_name, prompt]) for prompt in prompt_list]
@@ -16,7 +16,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] MCTACO
- [x] Pubmed QA
- [x] SciQ
- [x] QASPER
- [x] QA4MRE
- [x] TriviaQA
- [x] AI2 ARC

@@ -36,7 +36,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] TruthfulQA (mc1)
- [x] TruthfulQA (mc2)
- [x] TruthfulQA (gen)
- [x] MuTual
- [ ] Hendrycks Math (Hailey)
- [x] Asdiv
- [ ] GSM8k
...
import os
import yaml
from typing import List, Union, Dict

from lm_eval import utils
from lm_eval import prompts

@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
)


def register_configurable_task(config: Dict[str, str]) -> int:
    SubClass = type(
        config["task"] + "ConfigurableTask",
        (ConfigurableTask,),

@@ -38,7 +38,35 @@ def register_configurable_task(config: dict[str, str]) -> int:
    return 0


def register_configurable_group(config: Dict[str, str]) -> int:
    group = config["group"]
    all_task_list = config["task"]
    config_list = [task for task in all_task_list if type(task) != str]
    task_list = [task for task in all_task_list if type(task) == str]

    for task_config in config_list:
        var_configs = check_prompt_config(
            {
                **task_config,
                **{"group": group},
            }
        )
        for config in var_configs:
            register_configurable_task(config)

    task_names = utils.pattern_match(task_list, ALL_TASKS)
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            if group in GROUP_REGISTRY:
                GROUP_REGISTRY[group].append(task)
            else:
                GROUP_REGISTRY[group] = [task]
            ALL_TASKS.add(group)
    return 0


def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
    all_configs = []
    if "use_prompt" in config:
        prompt_list = prompts.load_prompt_list(

@@ -69,14 +97,14 @@ def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
    return all_configs


def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    if "dataset_name" in task_config:
        return "{dataset_path}_{dataset_name}".format(**task_config)
    else:
        return "{dataset_path}".format(**task_config)


def include_task_folder(task_dir: str, register_task=True) -> None:
    """
    Calling this function
    """

@@ -87,9 +115,16 @@ def include_task_folder(task_dir: str) -> None:
                yaml_path = os.path.join(root, f)
                try:
                    config = utils.load_yaml_config(yaml_path)
                    if register_task:
                        all_configs = check_prompt_config(config)
                        for config in all_configs:
                            register_configurable_task(config)
                    else:
                        # If a `task` in config is a list,
                        # that means it's a benchmark
                        if type(config["task"]) == list:
                            register_configurable_group(config)
                except Exception as error:
                    eval_logger.warning(

@@ -102,6 +137,8 @@ def include_task_folder(task_dir: str) -> None:
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
# Register Benchmarks after all tasks have been added
include_task_folder(task_dir, register_task=False)


def get_task(task_name, config):

@@ -128,7 +165,7 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    config = {**kwargs}

@@ -136,6 +173,9 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if type(task_name_list) != list:
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str):

@@ -143,12 +183,20 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
                group_name = task_element
                for task_name in GROUP_REGISTRY[task_element]:
                    if task_name not in task_name_from_registry_dict:
                        task_obj = get_task_dict(task_name)
                        if task_name in task_obj.keys():
                            task_dict = {
                                task_name: (group_name, task_obj[task_name]),
                            }
                        else:
                            task_dict = {
                                task_name: (group_name, None),
                                **task_obj,
                            }

                        task_name_from_registry_dict = {
                            **task_name_from_registry_dict,
                            **task_dict,
                        }
            else:
                task_name = task_element
...
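A shape illustration (placeholder strings instead of real task objects) of what the recursion above produces when a requested group contains another group. "pythia" and its members come from the benchmark YAML shown next in this commit; ai2_arc is assumed here to be registered as a group of arc_easy and arc_challenge:

# Each member maps to (group_name, task_object); a nested group maps to
# (parent_group, None) and contributes its own members alongside it, which the
# evaluator then skips via the new "if task is None: continue" checks.
task_dict = {
    "lambada_openai": ("pythia", "<ConfigurableTask lambada_openai>"),
    "wikitext": ("pythia", "<ConfigurableTask wikitext>"),
    "ai2_arc": ("pythia", None),
    "arc_easy": ("ai2_arc", "<ConfigurableTask arc_easy>"),
    "arc_challenge": ("ai2_arc", "<ConfigurableTask arc_challenge>"),
}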
group: pythia
task:
  - lambada_openai
  - logiqa
  - piqa
  - sciq
  - wikitext
  - winogrande
  - wsc
  - ai2_arc
  - blimp
  - hendrycksTest*
# Generated by utils.py
dataset_name: bn
doc_to_target: '{% if answer is not none %}{{answer[16+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nধাপে ধাপে উত্তর:"}}{% else
%}{{"প্রশ্ন: "+question+"\nধাপে ধাপে উত্তর:"}}{% endif %}'
include: cot_yaml
task: mgsm_bn_direct
# Generated by utils.py
dataset_name: de
doc_to_target: '{% if answer is not none %}{{answer[28+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nSchritt-für-Schritt-Antwort:"}}{%
else %}{{"Frage: "+question+"\nSchritt-für-Schritt-Antwort:"}}{% endif %}'
include: cot_yaml
task: mgsm_de_direct
# Generated by utils.py
dataset_name: en
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else
%}{{"Question: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_en_direct
# Generated by utils.py
dataset_name: es
doc_to_target: '{% if answer is not none %}{{answer[22+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nRespuesta paso a paso:"}}{%
else %}{{"Pregunta: "+question+"\nRespuesta paso a paso:"}}{% endif %}'
include: cot_yaml
task: mgsm_es_direct
# Generated by utils.py
dataset_name: fr
doc_to_target: '{% if answer is not none %}{{answer[25+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nRéponse étape par étape :"}}{%
else %}{{"Question : "+question+"\nRéponse étape par étape :"}}{% endif %}'
include: cot_yaml
task: mgsm_fr_direct
# Generated by utils.py
dataset_name: ja
doc_to_target: '{% if answer is not none %}{{answer[10+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nステップごとの答え:"}}{% else %}{{"問題:
"+question+"\nステップごとの答え:"}}{% endif %}'
include: cot_yaml
task: mgsm_ja_direct
# Generated by utils.py
dataset_name: ru
doc_to_target: '{% if answer is not none %}{{answer[17+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nПошаговоерешение:"}}{% else
%}{{"Задача: "+question+"\nПошаговоерешение:"}}{% endif %}'
include: cot_yaml
task: mgsm_ru_direct