Commit 50e99bd7 authored by Herbie Bradley's avatar Herbie Bradley
Browse files

Merge remote-tracking branch 'origin/big-refactor' into calibration

parents 3d4c4cd6 a3252ed7
name: Tasks Modified
# name: Tasks Modified
on:
push:
branches:
- 'big-refactor*'
pull_request:
branches:
- 'big-refactor*'
workflow_dispatch:
# comment/edit out the above to stop/change the triggers
jobs:
changed_files:
runs-on: ubuntu-latest # windows-latest || macos-latest
timeout-minutes: 120
name: Scan for changed tasks
steps:
- name: checkout
uses: actions/checkout@v3
with:
fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
# on:
# push:
# branches:
# - 'big-refactor*'
# pull_request:
# branches:
# - 'big-refactor*'
# workflow_dispatch:
# # comment/edit out the above to stop/change the triggers
# jobs:
# changed_files:
# runs-on: ubuntu-latest # windows-latest || macos-latest
# timeout-minutes: 120
# name: Scan for changed tasks
# steps:
# - name: checkout
# uses: actions/checkout@v3
# with:
# fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
# Uses the tj-actions/changed-files@v37 action to check for changes.
# Outputs provided here: https://github.com/tj-actions/changed-files#outputs
# The `files_yaml` input optionally takes a yaml string to specify filters,
# and prepends the filter name to the standard output names.
- name: Check task folders
id: changed-tasks
uses: tj-actions/changed-files@v37.1.2
with:
# tasks checks the tasks folder and api checks the api folder for changes
files_yaml: |
tasks:
- lm_eval/tasks/**
api:
- lm_eval/api/**
write_output_files: true
# # Uses the tj-actions/changed-files@v37 action to check for changes.
# # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
# # The `files_yaml` input optionally takes a yaml string to specify filters,
# # and prepends the filter name to the standard output names.
# - name: Check task folders
# id: changed-tasks
# uses: tj-actions/changed-files@v37.1.2
# with:
# # tasks checks the tasks folder and api checks the api folder for changes
# files_yaml: |
# tasks:
# - lm_eval/tasks/**
# api:
# - lm_eval/api/**
# write_output_files: true
# The next step is optional; the files are written to the workspace by default (above).
# so it's just for debugging
- name: Run Tests
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
run: |
echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
echo "One or more test file(s) has changed."
echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
# # The next step is optional; the files are written to the workspace by default (above).
# # so it's just for debugging
# - name: Run Tests
# if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
# run: |
# echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
# echo "One or more test file(s) has changed."
# echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
- name: Set up Python 3.9
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
uses: actions/setup-python@v4
with:
python-version: 3.9
cache: 'pip'
cache-dependency-path: setup.py
- name: Install dependencies
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
run: |
python -m pip install --upgrade pip
pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
# if new tasks are added, run tests on them
if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
run: python -m pytest tests/test_tasks.py -s -vv -n=auto
# if api is modified, run tests on it
- name: Test more tasks with pytest
env:
API: true
if: steps.changed-tasks.outputs.api_any_modified == 'true'
run: python -m pytest tests/test_tasks.py -s -vv -n=auto
# - name: Set up Python 3.9
# if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
# uses: actions/setup-python@v4
# with:
# python-version: 3.9
# cache: 'pip'
# cache-dependency-path: setup.py
# - name: Install dependencies
# if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
# run: |
# python -m pip install --upgrade pip
# pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
# # Install optional git dependencies
# # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
# - name: Test with pytest
# # if new tasks are added, run tests on them
# if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
# run: python -m pytest tests/test_tasks.py -s -vv
# # if api is modified, run tests on it
# - name: Test more tasks with pytest
# env:
# API: true
# if: steps.changed-tasks.outputs.api_any_modified == 'true'
# run: python -m pytest tests/test_tasks.py -s -vv
......@@ -40,39 +40,38 @@ jobs:
flake8 . --count --select=F,E9,E71,E72,E501,E112,E113,W6 --extend-ignore=F541 --show-source --statistics --exit-zero
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
# mypy turned off for now
# # mypy turned off for now
# - name: Lint with mypy
# run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
# Job 2
testcpu:
name: CPU Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.9", "3.10", "3.11" ]
timeout-minutes: 30
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: setup.py
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra
- name: Archive artifacts
uses: actions/upload-artifact@v3
with:
name: output_results
path: |
test_logs/*
# testcpu:
# name: CPU Tests
# runs-on: ubuntu-latest
# strategy:
# matrix:
# python-version: [ "3.8", "3.9", "3.10", "3.11" ]
# timeout-minutes: 30
# steps:
# - name: Checkout Code
# uses: actions/checkout@v3
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v4
# with:
# python-version: ${{ matrix.python-version }}
# cache: pip
# cache-dependency-path: setup.py
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
# # Install optional git dependencies
# # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
# - name: Test with pytest
# run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/tests_master --ignore=tests/extra
# - name: Archive artifacts
# uses: actions/upload-artifact@v3
# with:
# name: output_results
# path: |
# test_logs/*
import abc
import os
from typing import Union, List, Tuple
import torch
from typing import Union, List, Tuple, Optional, Type, TypeVar
from sqlitedict import SqliteDict
import json
import hashlib
......@@ -11,6 +12,8 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
......@@ -111,11 +114,28 @@ class LM(abc.ABC):
pass
@classmethod
def create_from_arg_string(cls, arg_string, additional_config=None):
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
Parameters:
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
if args2.get("device") == "mps" or args.get("device") == "mps":
# TODO: delete once float16 MPS is fixed in torch stable
if (
args2.get("device") in ("mps", "mps:0")
or args.get("device") in ("mps", "mps:0")
and "dev" not in torch.__version__
):
args["dtype"] = "float32"
return cls(**args, **args2)
......
......@@ -674,22 +674,22 @@ class ConfigurableTask(Task):
check_choices = test_choice
else:
check_choices = [test_target]
for choice in check_choices:
choice_has_whitespace = True if " " in choice else False
delimiter_has_whitespace = (
True if " " in self.config.target_delimiter else False
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.warning(
f'Both target_delimiter and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.warning(
f'Both target_delimiter and target choice: "{choice}" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
delimiter_has_whitespace = (
True if self.config.target_delimiter[-1].isspace() else False
)
if delimiter_has_whitespace and choice_has_whitespace:
eval_logger.warning(
f'Both target_delimiter and target choice: "{choice}" have whitespace'
)
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
eval_logger.warning(
f'Both target_delimiter and target choice: "{choice}" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(self, dataset_kwargs=None) -> None:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
......@@ -1080,6 +1080,9 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
else:
gold = str(gold)
......@@ -1090,6 +1093,10 @@ class ConfigurableTask(Task):
# return true if any are true
# TODO: this may break for multipLe_target, non zero-or-1 metrics
scores = []
if not isinstance(gold, list):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
for gold_option in gold:
try:
result_score = self._metric_fn_list[metric](
......
import os
import yaml
from lm_eval import utils
from lm_eval.tasks import register_configurable_task, check_prompt_config
from lm_eval.logger import eval_logger
from lm_eval.api.registry import (
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
)
def include_benchmarks(task_dir: str) -> None:
for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
try:
benchmark_path = os.path.join(root, f)
with open(benchmark_path, "rb") as file:
yaml_config = yaml.full_load(file)
assert "group" in yaml_config
group = yaml_config["group"]
all_task_list = yaml_config["task"]
config_list = [
task for task in all_task_list if type(task) != str
]
task_list = [
task for task in all_task_list if type(task) == str
]
for task_config in config_list:
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
)
for config in var_configs:
register_configurable_task(config)
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if task in TASK_REGISTRY:
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
except Exception as error:
eval_logger.warning(
"Failed to load benchmark in\n"
f" {benchmark_path}\n"
" Benchmark will not be added to registry\n"
f" Error: {error}"
)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_benchmarks(task_dir)
......@@ -3,7 +3,7 @@ import string
import pickle
import traceback
from pprint import pprint
from typing import Iterator, Sequence, TypeVar
from typing import Iterator, Sequence, TypeVar, List, Tuple
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
......@@ -21,7 +21,7 @@ T = TypeVar("T")
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[tuple[T, ...]]:
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
......@@ -70,14 +70,14 @@ def word_ngrams(s: str, n: int) -> Iterator[str]:
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s: str) -> Iterator[tuple[str, tuple[int, int]]]:
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
def word_ngrams_indices(s: str, n: int) -> Iterator[tuple[str, tuple[int, int]]]:
def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
......@@ -157,7 +157,7 @@ class Janitor:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string: str) -> list[str]:
def clean(self, dirty_string: str) -> List[str]:
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
......@@ -168,8 +168,8 @@ class Janitor:
return self.clean_python(dirty_string)
def _split_chunks(
self, dirty_string: str, dirty_parts: Sequence[tuple]
) -> list[str]:
self, dirty_string: str, dirty_parts: Sequence[Tuple]
) -> List[str]:
clean_chunks = []
splice_idx = 0
end = -1
......@@ -197,7 +197,7 @@ class Janitor:
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string: str) -> list[str]:
def clean_cpp(self, dirty_string: str) -> List[str]:
contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
......@@ -215,7 +215,7 @@ class Janitor:
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string: str) -> list[str]:
def clean_python(self, dirty_string: str) -> List[str]:
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
......
......@@ -118,6 +118,8 @@ def simple_evaluate(
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
group, task_obj = task_obj
if task_obj is None:
continue
config = task_obj._config
if num_fewshot is not None:
......@@ -207,23 +209,30 @@ def evaluate(
samples = collections.defaultdict(list)
# tracks all Instances/requests a model must generate output on.
requests = collections.defaultdict(list)
# Stores task scores based on task grouping.
aggregate = collections.defaultdict(dict)
# tracks if a task was chosen via user selecting a group containing it
task_groups = collections.defaultdict(dict)
# Aggregated task scores presented with groups
results_agg = collections.defaultdict(dict)
# Aggregated groups scores only
groups_agg = collections.defaultdict(dict)
# stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal
padding_requests = collections.defaultdict(int)
# Stores group related keys and values for group-aggregation
task_groups = collections.defaultdict(dict)
# store the hierarchy to do proper ordering
task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
# store the aggregation for aggregating across tasks in the same group
sample_agg_fn = collections.defaultdict(dict)
# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
task_groups[task_name] = group
aggregate[task_name] = {}
group_name, task = task
task_hierarchy[group_name].append(task_name)
else:
task_hierarchy[task_name] = []
if task is None:
continue
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
......@@ -299,6 +308,8 @@ def evaluate(
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
if task is None:
continue
task.apply_filters()
### Collect values of metrics on all datapoints ###
......@@ -308,6 +319,8 @@ def evaluate(
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
if task is None:
continue
# TODO: make it possible to use a different metric per filter
# iterate over different filters used
for key in task.instances[0].filtered_resps.keys():
......@@ -468,27 +481,62 @@ def evaluate(
vals = vals_torch
if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation
group_to_task = {}
for group in task_hierarchy.keys():
if group not in task_order:
task_order[group] = 0
if len(task_hierarchy[group]) > 0:
group_to_task[group] = task_hierarchy[group].copy()
for task in task_hierarchy[group]:
if task in task_order:
task_order[task] += 1
else:
task_order[task] = 1 + task_order[group]
if task in task_hierarchy:
group_to_task[group].remove(task)
group_to_task[group].extend(task_hierarchy[task])
task_to_group = {}
for group in group_to_task:
for task in group_to_task[group]:
if task in task_to_group:
task_to_group[task].append(group)
else:
task_to_group[task] = [group]
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
metric_key = metric + "," + key
if type(task) == tuple:
group, task = task
task_score = task.aggregation()[metric](items)
results[task_name][metric + "," + key] = task_score
# Need to put back in results
# pythia | acc
# | perplexity
# | word_perplexity
# | byte_perplexity
# | bits_per_byte
if task_name in task_groups:
group_name = task_groups[task_name]
if metric in list(aggregate[group_name].keys()):
aggregate[group_name][metric].append(task_score)
else:
aggregate[group_name][metric] = [task_score]
group_name, task = task
else:
group_name = None
agg_fn = task.aggregation()[metric]
task_score = agg_fn(items)
if group_name is not None:
sample_metric_key = metric + "(sample agg)," + key
for grouping in task_to_group[task_name]:
if metric_key in results[grouping]:
results[grouping][metric_key].append(task_score)
else:
results[grouping][metric_key] = [task_score]
if sample_metric_key in results[grouping]:
results[grouping][sample_metric_key] += items
else:
results[grouping][sample_metric_key] = items.copy()
sample_agg_fn[grouping][sample_metric_key] = agg_fn
results[task_name][metric_key] = task_score
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
......@@ -503,19 +551,38 @@ def evaluate(
if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if bool(aggregate):
for group in aggregate.keys():
for metric in aggregate[group].keys():
aggregate[group][metric] = np.average(aggregate[group][metric])
versions[group] = "N/A"
if bool(results):
for task_or_group in results.keys():
for metric in results[task_or_group].keys():
if type(results[task_or_group][metric]) == list:
if "(sample agg)" in metric:
results[task_or_group][metric] = sample_agg_fn[
task_or_group
][metric](results[task_or_group][metric])
else:
results[task_or_group][metric] = np.average(
results[task_or_group][metric]
)
versions[task_or_group] = "N/A"
for task_name, task in task_dict.items():
if type(task) == tuple:
group_name, task = task
order = task_order[group_name]
tabbed_name = "-" * order + group_name
results_agg[tabbed_name] = results[group_name]
versions[tabbed_name] = versions[group_name]
if order == 0:
groups_agg[group_name] = results[group_name]
order = task_order[task_name]
tabbed_name = "-" * order + task_name
results_agg[tabbed_name] = results[task_name]
versions[tabbed_name] = versions[task_name]
results_dict = {
"results": dict(sorted(results.items())),
**(
{"aggregate": dict(sorted(aggregate.items()))}
if bool(aggregate)
else {}
),
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
}
......
......@@ -101,17 +101,20 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu", "mps"]
["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+ ["mps", "mps:0"]
)
if device:
if device not in device_list:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
if device == "mps":
if device in ("mps", "mps:0") and "dev" not in torch.__version__:
eval_logger.info(
"MPS is still in beta and only supports float32; setting dtype to float32."
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cpu"
)
else:
eval_logger.info("Device not specified")
......
import ast
from typing import Dict
from lm_eval import utils
from lm_eval.logger import eval_logger
......@@ -5,7 +8,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
PROMPT_REGISTRY: dict[str, dict[str, str]] = {
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
"qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {{question}}\nA:",
......@@ -63,6 +66,12 @@ def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwa
else:
prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)
category_name, prompt_name = use_prompt.split(":")
category_name, *prompt_name = use_prompt.split(":")
# TODO allow to multiple prompt naming
# if len(prompt_name) > 1:
# prompt_list = []
# for prompt in prompt_name:
# prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
# else:
prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
return [":".join([category_name, prompt]) for prompt in prompt_list]
......@@ -16,7 +16,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] MCTACO
- [x] Pubmed QA
- [x] SciQ
- [ ] QASPER
- [x] QASPER
- [x] QA4MRE
- [x] TriviaQA
- [x] AI2 ARC
......@@ -36,7 +36,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] TruthfulQA (mc1)
- [x] TruthfulQA (mc2)
- [x] TruthfulQA (gen)
- [ ] MuTual
- [x] MuTual
- [ ] Hendrycks Math (Hailey)
- [x] Asdiv
- [ ] GSM8k
......
import os
import yaml
from typing import List, Union
from typing import List, Union, Dict
from lm_eval import utils
from lm_eval import prompts
......@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
)
def register_configurable_task(config: dict[str, str]) -> int:
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
......@@ -38,7 +38,35 @@ def register_configurable_task(config: dict[str, str]) -> int:
return 0
def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
def register_configurable_group(config: Dict[str, str]) -> int:
group = config["group"]
all_task_list = config["task"]
config_list = [task for task in all_task_list if type(task) != str]
task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list:
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
)
for config in var_configs:
register_configurable_task(config)
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
return 0
def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
all_configs = []
if "use_prompt" in config:
prompt_list = prompts.load_prompt_list(
......@@ -69,14 +97,14 @@ def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
return all_configs
def get_task_name_from_config(task_config: dict[str, str]) -> str:
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir: str) -> None:
def include_task_folder(task_dir: str, register_task=True) -> None:
"""
Calling this function
"""
......@@ -87,9 +115,16 @@ def include_task_folder(task_dir: str) -> None:
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
all_configs = check_prompt_config(config)
for config in all_configs:
register_configurable_task(config)
if register_task:
all_configs = check_prompt_config(config)
for config in all_configs:
register_configurable_task(config)
else:
# If a `task` in config is a list,
# that means it's a benchmark
if type(config["task"]) == list:
register_configurable_group(config)
except Exception as error:
eval_logger.warning(
......@@ -102,6 +137,8 @@ def include_task_folder(task_dir: str) -> None:
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
# Register Benchmarks after all tasks have been added
include_task_folder(task_dir, register_task=False)
def get_task(task_name, config):
......@@ -128,7 +165,7 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
config = {**kwargs}
......@@ -136,6 +173,9 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
task_name_from_config_dict = {}
task_name_from_object_dict = {}
if type(task_name_list) != list:
task_name_list = [task_name_list]
for task_element in task_name_list:
if isinstance(task_element, str):
......@@ -143,12 +183,20 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
if task_name not in task_name_from_registry_dict:
task_obj = get_task_dict(task_name)
if task_name in task_obj.keys():
task_dict = {
task_name: (group_name, task_obj[task_name]),
}
else:
task_dict = {
task_name: (group_name, None),
**task_obj,
}
task_name_from_registry_dict = {
**task_name_from_registry_dict,
task_name: (
group_name,
get_task(task_name=task_name, config=config),
),
**task_dict,
}
else:
task_name = task_element
......
group: pythia
task:
- lambada_openai
- wikitext
- logiqa
- piqa
- sciq
- wsc
- wikitext
- winogrande
- arc
- logiqa
- wsc
- ai2_arc
- blimp
- hendrycksTest*
# Generated by utils.py
dataset_name: bn
doc_to_target: '{% if answer is not none %}{{answer[16+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nধাপে ধাপে উত্তর:"}}{% else
%}{{"প্রশ্ন: "+question+"\nধাপে ধাপে উত্তর:"}}{% endif %}'
include: cot_yaml
task: mgsm_bn_direct
# Generated by utils.py
dataset_name: de
doc_to_target: '{% if answer is not none %}{{answer[28+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nSchritt-für-Schritt-Antwort:"}}{%
else %}{{"Frage: "+question+"\nSchritt-für-Schritt-Antwort:"}}{% endif %}'
include: cot_yaml
task: mgsm_de_direct
# Generated by utils.py
dataset_name: en
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else
%}{{"Question: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_en_direct
# Generated by utils.py
dataset_name: es
doc_to_target: '{% if answer is not none %}{{answer[22+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nRespuesta paso a paso:"}}{%
else %}{{"Pregunta: "+question+"\nRespuesta paso a paso:"}}{% endif %}'
include: cot_yaml
task: mgsm_es_direct
# Generated by utils.py
dataset_name: fr
doc_to_target: '{% if answer is not none %}{{answer[25+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nRéponse étape par étape :"}}{%
else %}{{"Question : "+question+"\nRéponse étape par étape :"}}{% endif %}'
include: cot_yaml
task: mgsm_fr_direct
# Generated by utils.py
dataset_name: ja
doc_to_target: '{% if answer is not none %}{{answer[10+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nステップごとの答え:"}}{% else %}{{"問題:
"+question+"\nステップごとの答え:"}}{% endif %}'
include: cot_yaml
task: mgsm_ja_direct
# Generated by utils.py
dataset_name: ru
doc_to_target: '{% if answer is not none %}{{answer[17+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nПошаговоерешение:"}}{% else
%}{{"Задача: "+question+"\nПошаговоерешение:"}}{% endif %}'
include: cot_yaml
task: mgsm_ru_direct
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment