Unverified Commit cda25fef authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge branch 'main' into standardize_metrics

parents dfb41835 4d10ad56
......@@ -33,4 +33,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 2.0
version: 2.0
......@@ -10,4 +10,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 2.0
version: 2.0
......@@ -6,7 +6,6 @@ from rouge_score import rouge_scorer, scoring
def process_results_mc2(doc, results):
lls, is_greedy = zip(*results)
# Split on the first `0` as everything before it is true (`1`).
......@@ -20,7 +19,6 @@ def process_results_mc2(doc, results):
def process_docs_gen(dataset: datasets.Dataset) -> datasets.Dataset:
return dataset.map(preprocess_function)
......@@ -49,7 +47,6 @@ def preprocess_function(examples):
def process_results_gen(doc, results):
completion = results[0]
true_refs, false_refs = doc["correct_answers"], doc["incorrect_answers"]
all_refs = true_refs + false_refs
......
......@@ -17,4 +17,4 @@ metric_list:
ignore_case: false
ignore_punctuation: false
metadata:
- version: 1.0
version: 1.0
......@@ -17,4 +17,4 @@ metric_list:
ignore_case: false
ignore_punctuation: false
metadata:
- version: 1.0
version: 1.0
......@@ -17,4 +17,4 @@ metric_list:
ignore_case: false
ignore_punctuation: false
metadata:
- version: 1.0
version: 1.0
......@@ -17,4 +17,4 @@ metric_list:
ignore_case: false
ignore_punctuation: false
metadata:
- version: 1.0
version: 1.0
......@@ -17,4 +17,4 @@ metric_list:
ignore_case: false
ignore_punctuation: false
metadata:
- version: 1.0
version: 1.0
......@@ -17,4 +17,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
version: 1.0
......@@ -15,4 +15,4 @@ metric_list:
- metric: byte_perplexity
- metric: bits_per_byte
metadata:
- version: 2.0
version: 2.0
......@@ -14,4 +14,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
version: 1.0
......@@ -16,4 +16,4 @@ metric_list:
aggregation: !function metrics.agg_bleu
higher_is_better: true
metadata:
- version: 0.0
version: 0.0
......@@ -14,4 +14,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
version: 1.0
......@@ -11,4 +11,4 @@ doc_to_choice: !function utils.doc_to_choice
metric_list:
- metric: acc
metadata:
- version: 1.0
version: 1.0
import argparse
from typing import Dict, List
import yaml
......
......@@ -16,4 +16,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
version: 1.0
......@@ -15,4 +15,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
version: 1.0
......@@ -17,4 +17,4 @@ metric_list:
aggregation: mean
higher_is_better: true
metadata:
- version: 1.0
version: 1.0
import os
import re
import sys
import yaml
import collections
import fnmatch
import functools
import gc
import importlib.util
import inspect
import logging
import os
import pathlib
import functools
import re
import subprocess
import collections
import importlib.util
import fnmatch
from typing import Iterator, List, Literal, Union, Any, Callable
import sys
import time
from functools import wraps
from itertools import islice
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
import gc
import torch
import transformers
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice
import logging
logging.basicConfig(
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
......@@ -143,7 +154,7 @@ class MultiChoice:
def __contains__(self, values) -> bool:
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.info(f"Available tasks to choose:")
eval_logger.info("Available tasks to choose:")
for choice in self.choices:
eval_logger.info(f" - {choice}")
raise ValueError("'{}' is not in task list".format(value))
......@@ -157,7 +168,7 @@ class MultiChoice:
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
if type(patterns) == str:
if isinstance(patterns, str):
patterns = [patterns]
task_names = set()
......@@ -332,7 +343,7 @@ class Grouper:
def make_table(result_dict, column: str = "results"):
"""Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter
from pytablewriter import LatexTableWriter, MarkdownTableWriter
if column == "results":
column_name = "Tasks"
......@@ -466,7 +477,7 @@ def import_function(loader, node):
yaml_path = os.path.dirname(loader.name)
*module_name, function_name = function_name.split(".")
if type(module_name) == list:
if isinstance(module_name, list):
module_name = ".".join(module_name)
module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))
......@@ -496,7 +507,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
include_path = yaml_config["include"]
del yaml_config["include"]
if type(include_path) == str:
if isinstance(include_path, str):
include_path = [include_path]
# Load from the last one first
......@@ -671,7 +682,7 @@ def divide(iterable, n) -> List[Iterator]:
"""Divide the elements from *iterable* into *n* parts, maintaining
order.
>>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
>>> group_1, group_2 = divide([1, 2, 3, 4, 5, 6], 2)
>>> list(group_1)
[1, 2, 3]
>>> list(group_2)
......@@ -680,14 +691,14 @@ def divide(iterable, n) -> List[Iterator]:
If the length of *iterable* is not evenly divisible by *n*, then the
length of the returned iterables will not be identical:
>>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
>>> children = divide([1, 2, 3, 4, 5, 6, 7], 3)
>>> [list(c) for c in children]
[[1, 2, 3], [4, 5], [6, 7]]
If the length of the iterable is smaller than n, then the last returned
iterables will be empty:
>>> children = divide(5, [1, 2, 3])
>>> children = divide([1, 2, 3], 5)
>>> [list(c) for c in children]
[[1], [2], [3], [], []]
......@@ -716,3 +727,205 @@ def divide(iterable, n) -> List[Iterator]:
ret.append(iter(seq[start:stop]))
return ret
def retry_on_specific_exceptions(
on_exceptions: List[Type[Exception]],
max_retries: Optional[int] = None,
backoff_time: float = 3.0,
backoff_multiplier: float = 1.5,
on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
"""Retry on an LLM Provider's rate limit error with exponential backoff
For example, to use for OpenAI, do the following:
```
from openai import RateLimitError
# Recommend specifying max_retries to avoid infinite loops!
@retry_on_specific_exceptions([RateLimitError], max_retries=3)
def completion(...):
# Wrap OpenAI completion function here
...
```
"""
def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
sleep_time = backoff_time
attempt = 0
while max_retries is None or attempt < max_retries:
try:
return func(*args, **kwargs)
except tuple(on_exceptions) as e:
if on_exception_callback is not None:
on_exception_callback(e, sleep_time)
time.sleep(sleep_time)
sleep_time *= backoff_multiplier
attempt += 1
return wrapper
return decorator
class Collator:
"""
A class for reordering and batching elements of an array.
This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
"""
def __init__(
self,
arr: List,
sort_fn: Callable,
group_fn: Callable = lambda x: x[1],
grouping: bool = False,
) -> None:
self.grouping = grouping
self.fn = sort_fn
self.group_fn = lambda x: group_fn(x[1]) # first index are enumerated indices
self.reorder_indices: List = []
self.size = len(arr)
self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr)) # [indices, (arr)]
if self.grouping is True:
self.group_by_index()
def group_by_index(self) -> None:
self.arr_with_indices = self.group(
self.arr_with_indices, fn=self.group_fn, values=False
)
def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
"""
Generates and yields batches from the reordered array.
Parameters:
- n (int): The size of each batch. Defaults to 1.
- batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.
Yields:
Iterator: An iterator over batches of reordered elements.
"""
if self.grouping:
for (
key,
values,
) in self.arr_with_indices.items(): # type: ignore
values = self._reorder(values)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
else:
values = self._reorder(self.arr_with_indices) # type: ignore
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
"""
Reorders the elements in the array based on the sorting function.
Parameters:
- arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.
Yields:
List: Yields reordered elements one by one.
"""
arr = sorted(arr, key=lambda x: self.fn(x[1]))
self.reorder_indices.extend([x[0] for x in arr])
yield from [x[1] for x in arr]
def get_original(self, newarr: List) -> List:
"""
Restores the original order of elements from the reordered list.
Parameters:
- newarr (List): The reordered array.
Returns:
List: The array with elements restored to their original order.
"""
res = [None] * self.size
cov = [False] * self.size
for ind, v in zip(self.reorder_indices, newarr):
res[ind] = v
cov[ind] = True
assert all(cov)
return res
def __len__(self):
return self.size
@staticmethod
def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
"""
Groups elements of an iterable based on a provided function.
Parameters:
- arr (Iterable): The iterable to be grouped.
- fn (Callable): The function to determine the grouping.
- values (bool): If True, returns the values of the group. Defaults to False.
Returns:
Iterable: An iterable of grouped elements.
"""
res = collections.defaultdict(list)
for ob in arr:
try:
hashable_dict = tuple(
(
key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
)
for key, value in sorted(fn(ob).items())
)
res[hashable_dict].append(ob)
except TypeError:
res[fn(ob)].append(ob)
if not values:
return res
return res.values()
@staticmethod
def get_chunks(_iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
Parameters:
- iter: The input iterable to be divided into chunks.
- n: An integer representing the size of each chunk. Default is 0.
- fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
Returns:
An iterator that yields chunks of the input iterable.
Example usage:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for chunk in chunks(data, 3):
print(chunk)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]
```
"""
arr = []
_iter = tuple(_iter)
for i, x in enumerate(_iter):
arr.append(x)
if len(arr) == (fn(i, _iter) if fn else n):
yield arr
arr = []
if arr:
yield arr
......@@ -54,35 +54,48 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
[project.optional-dependencies]
dev = ["black", "flake8", "pre-commit", "pytest", "pytest-cov"]
linting = [
"flake8",
"pylint",
"mypy",
"pre-commit",
]
testing = ["pytest", "pytest-cov", "pytest-xdist"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
anthropic = ["anthropic"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
ifeval = ["langdetect", "immutabledict"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
openai = ["openai==1.3.9", "tiktoken"]
promptsource = [
"promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
]
gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
anthropic = ["anthropic"]
openai = ["openai==1.3.9", "tiktoken"]
vllm = ["vllm"]
ifeval = ["langdetect", "immutabledict"]
sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
testing = ["pytest", "pytest-cov", "pytest-xdist"]
vllm = ["vllm<=0.2.5"]
zeno = ["pandas", "zeno-client"]
all = [
"lm_eval[anthropic]",
"lm_eval[dev]",
"lm_eval[testing]",
"lm_eval[gptq]",
"lm_eval[ifeval]",
"lm_eval[linting]",
"lm_eval[mamba]",
"lm_eval[math]",
"lm_eval[multilingual]",
"lm_eval[sentencepiece]",
"lm_eval[promptsource]",
"lm_eval[gptq]",
"lm_eval[anthropic]",
"lm_eval[openai]",
"lm_eval[promptsource]",
"lm_eval[sentencepiece]",
"lm_eval[testing]",
"lm_eval[vllm]",
"lm_eval[ifeval]",
"lm_eval[zeno]",
]
[tool.ruff]
extend-exclude = ["lm_eval/evaluator.py", "lm_eval/tasks/*.py"]
[tool.ruff.lint]
extend-select = ["I"]
[tool.ruff.isort]
lines-after-imports = 2
known-first-party = ["lm_eval"]
[tool.ruff.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403","I"]
"lm_eval/tasks/*"= ["E721"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment