"...lm-evaluation-harness.git" did not exist on "82d57f655caa94d7175fa6ebda2fa6ddd8368bfa"
Unverified Commit 9fb2ebab authored by Baber Abbasi, committed by GitHub

Consolidate batching (#1197)



* refactor dataloader

* cleanup + add docs

* change arg

* renamed Collator and added testing

* parametrized test for Collator

* appease pre-commit

* added edge case batch 0 (no batching)

* fix typos

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent b12bb1d4
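In practice, this commit collapses the old `utils.Reorderer` + `utils.Grouper` + `utils.chunks` plumbing in the HF backend into a single `Collator` helper in `lm_eval.utils`. A minimal sketch of the new call pattern (the toy requests and sort key below are illustrative stand-ins, not the harness's real `Instance` objects):

```python
# Sketch of the consolidated batching flow: sort by a key, yield batches,
# then map results back to the caller's original order.
from lm_eval.utils import Collator

requests = [("a" * n, {"temperature": 0}) for n in range(1, 6)]  # toy stand-ins

re_ord = Collator(requests, sort_fn=lambda x: (-len(x[0]), x[0]))
results = []
for batch in re_ord.get_batched(n=2):             # batches of at most 2, longest first
    results.extend(len(ctx) for ctx, _ in batch)  # fake per-request "model output"

results = re_ord.get_original(results)            # restored to input order
```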
lm_eval/models/huggingface.py

 import copy
 import os
-from collections import defaultdict
 from pathlib import Path
 from typing import List, Literal, Optional, Tuple, Union

@@ -21,7 +20,7 @@ from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
-from lm_eval.utils import stop_sequences_criteria
+from lm_eval.utils import Collator, stop_sequences_criteria

 eval_logger = utils.eval_logger
@@ -632,7 +631,7 @@ class HFLM(LM):
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Tuple[List[int], List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
@@ -842,6 +841,7 @@ class HFLM(LM):
         res = []

         def _collate(x):
+            """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
@@ -852,26 +852,27 @@ class HFLM(LM):
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)

-        re_ord = utils.Reorderer(requests, _collate)
-        n_reordered_requests = len(re_ord.get_reordered())
+        re_ord = Collator(requests, sort_fn=_collate)

         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        chunks = utils.chunks(
-            re_ord.get_reordered(),
-            n=self.batch_size
-            if self.batch_size != "auto"
-            else override_bs
-            if override_bs is not None
-            else 0,
-            fn=self._batch_scheduler
-            if self.batch_size == "auto"
-            and n_reordered_requests > 0
-            and not override_bs
-            else None,
-        )
+        n_reordered_requests = len(re_ord)
+        batch_size = (
+            self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
+            if self.batch_size == "auto"
+            and n_reordered_requests > 0
+            and not override_bs
+            else None
+        )
+        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)

         pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
         for chunk in chunks:
             inps = []

@@ -1015,10 +1016,10 @@ class HFLM(LM):

         return re_ord.get_original(res)
     def generate_until(self, requests: List[Instance]) -> List[str]:
-        res = defaultdict(list)
-        re_ords = {}
+        res = []

         def _collate(x):
+            """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch

@@ -1028,14 +1029,6 @@ class HFLM(LM):
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]

-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
-        for key, reqs in grouper.get_grouped().items():
-            # within each set of reqs for given kwargs, we reorder by token length, descending.
-            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
         if self.batch_size == "auto":
             # using rolling window with maximum context

@@ -1044,18 +1037,24 @@ class HFLM(LM):
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
         # for each different set of kwargs, we execute all requests, by batch.
-        for key, re_ord in re_ords.items():
-            chunks = utils.chunks(
-                re_ord.get_reordered(),
-                n=self.batch_size
-                if self.batch_size != "auto"
-                else adaptive_batch_size
-                if adaptive_batch_size is not None
-                else 0,
-                fn=self._batch_scheduler
-                if self.batch_size == "auto" and not adaptive_batch_size
-                else None,
-            )
+        batch_size = (
+            self.batch_size
+            if self.batch_size != "auto"
+            else adaptive_batch_size
+            if adaptive_batch_size is not None
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
+            if self.batch_size == "auto" and not adaptive_batch_size
+            else None
+        )
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
+        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same

@@ -1127,15 +1126,13 @@ class HFLM(LM):
                         # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                         s = s.split(term)[0]

-                res[key].append(s)
-                self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), s
-                )
+                res.append(s)
+                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                 pbar.update(1)

         # reorder this group of results back to original unsorted form
-            res[key] = re_ord.get_original(res[key])
+        res = re_ords.get_original(res)

         pbar.close()
-        return grouper.get_original(res)
+        return res
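The per-gen_kwargs grouping that the removed `utils.Grouper` used to handle is now done by `Collator(..., grouping=True)`. A small self-contained sketch of that behaviour, with invented sample requests (not taken from the harness):

```python
from lm_eval.utils import Collator

# two different generation-kwarg dicts; requests with different kwargs
# must never be mixed into one batch
reqs = [
    ("aaaa", {"temperature": 0}),
    ("b", {"temperature": 0, "until": ["\n\n"]}),
    ("cc", {"temperature": 0}),
    ("ddd", {"temperature": 0, "until": ["\n\n"]}),
]

col = Collator(reqs, sort_fn=lambda x: (-len(x[0]), x[0]), grouping=True)
outputs = []
for batch in col.get_batched(n=2):
    # every element of a batch shares the same gen_kwargs dict
    assert all(kwargs == batch[0][1] for _, kwargs in batch)
    outputs.extend(ctx.upper() for ctx, _ in batch)  # dummy "generations"

outputs = col.get_original(outputs)  # ['AAAA', 'B', 'CC', 'DDD']
```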
lm_eval/utils.py

@@ -13,7 +13,18 @@ import sys
 import time
 from functools import wraps
 from itertools import islice
-from typing import Any, Callable, Iterator, List, Literal, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import torch
 import transformers

@@ -756,3 +767,157 @@ def retry_on_specific_exceptions(
         return wrapper

     return decorator
+
+
+class Collator:
+    """
+    A class for reordering and batching elements of an array.
+
+    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
+    """
+
+    def __init__(
+        self,
+        arr: List,
+        sort_fn: Callable,
+        group_fn: Callable = lambda x: x[1],
+        grouping: bool = False,
+    ) -> None:
+        self.grouping = grouping
+        self.fn = sort_fn
+        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
+        self.reorder_indices: List = []
+        self.size = len(arr)
+        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
+        if self.grouping is True:
+            self.group_by_index()
+
+    def group_by_index(self) -> None:
+        self.arr_with_indices = self.group(
+            self.arr_with_indices, fn=self.group_fn, values=False
+        )
+
+    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
+        """
+        Generates and yields batches from the reordered array.
+
+        Parameters:
+        - n (int): The size of each batch. Defaults to 1.
+        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.
+
+        Yields:
+        Iterator: An iterator over batches of reordered elements.
+        """
+        if self.grouping:
+            for (
+                key,
+                values,
+            ) in self.arr_with_indices.items():  # type: ignore
+                values = self._reorder(values)
+                batch = self.get_chunks(values, n=n, fn=batch_fn)
+                yield from batch
+        else:
+            values = self._reorder(self.arr_with_indices)  # type: ignore
+            batch = self.get_chunks(values, n=n, fn=batch_fn)
+            yield from batch
+
+    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
+        """
+        Reorders the elements in the array based on the sorting function.
+
+        Parameters:
+        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.
+
+        Yields:
+        List: Yields reordered elements one by one.
+        """
+        arr = sorted(arr, key=lambda x: self.fn(x[1]))
+        self.reorder_indices.extend([x[0] for x in arr])
+        yield from [x[1] for x in arr]
+
+    def get_original(self, newarr: List) -> List:
+        """
+        Restores the original order of elements from the reordered list.
+
+        Parameters:
+        - newarr (List): The reordered array.
+
+        Returns:
+        List: The array with elements restored to their original order.
+        """
+        res = [None] * self.size
+        cov = [False] * self.size
+
+        for ind, v in zip(self.reorder_indices, newarr):
+            res[ind] = v
+            cov[ind] = True
+
+        assert all(cov)
+
+        return res
+
+    def __len__(self):
+        return self.size
+
+    def group(self, arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
+        """
+        Groups elements of an iterable based on a provided function.
+
+        Parameters:
+        - arr (Iterable): The iterable to be grouped.
+        - fn (Callable): The function to determine the grouping.
+        - values (bool): If True, returns the values of the group. Defaults to False.
+
+        Returns:
+        Iterable: An iterable of grouped elements.
+        """
+        res = collections.defaultdict(list)
+        for ob in arr:
+            try:
+                hashable_dict = tuple(
+                    (key, tuple(value) if isinstance(value, list) else value)
+                    for key, value in sorted(ob[1][1].items())
+                )
+                res[hashable_dict].append(ob)
+            except TypeError:
+                res[fn(ob)].append(ob)
+        if not values:
+            return res
+        return res.values()
+
+    def get_chunks(self, iter, n: int = 0, fn=None):
+        """
+        Divides an iterable into chunks of specified size or based on a given function.
+        Useful for batching
+
+        Parameters:
+        - iter: The input iterable to be divided into chunks.
+        - n: An integer representing the size of each chunk. Default is 0.
+        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
+
+        Returns:
+        An iterator that yields chunks of the input iterable.
+
+        Example usage:
+        ```
+        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        for chunk in chunks(data, 3):
+            print(chunk)
+        ```
+        Output:
+        ```
+        [1, 2, 3]
+        [4, 5, 6]
+        [7, 8, 9]
+        [10]
+        ```
+        """
+        arr = []
+        for i, x in enumerate(iter):
+            arr.append(x)
+            if len(arr) == (fn(i, iter) if fn else n):
+                yield arr
+                arr = []
+
+        if arr:
+            yield arr
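To tie the new class together, here is a short usage sketch; the data and the `batch_fn` schedule below are made up for illustration and do not correspond to the harness's `_batch_scheduler`:

```python
from lm_eval.utils import Collator

# Minimal round trip: sort descending, batch with a caller-supplied size
# schedule, then map results back to the input order.
data = list(range(1, 8))                    # 7 toy "requests"
col = Collator(data, sort_fn=lambda x: -x)  # largest first

def batch_fn(idx, _iterable):
    # hypothetical schedule: first batch of 3, then batches of 2
    return 3 if idx < 3 else 2

processed = []
for batch in col.get_batched(n=0, batch_fn=batch_fn):
    processed.extend(x * 10 for x in batch)  # stand-in for model calls

assert col.get_original(processed) == [x * 10 for x in data]
```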
tests/test_utils.py

-from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
+import pytest
+
+from lm_eval.utils import Collator, get_rolling_token_windows, make_disjoint_window

 # noinspection DuplicatedCode

@@ -220,3 +222,76 @@ def test_make_disjoint_window():
     )
     assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
     assert make_disjoint_window(([1, 2, 3, 4, 5], [6])) == ([1, 2, 3, 4, 5], [6])
+
+
+class TestCollator:
+    def make_generate_sample(self, end=10):
+        strings = ["x" * i for i in range(1, end + 1)]
+        gen_kwargs1, gen_kwargs2 = (
+            {"temperature": 0},
+            {"temperature": 0, "until": ["nn", "\n\n"]},
+        )
+        args = [
+            (string, gen_kwargs1 if i < len(strings) // 2 else gen_kwargs2)
+            for i, string in enumerate(strings)
+        ]
+
+        return args
+
+    def make_loglikelihood_sample(self, end=11):
+        samples = [
+            (("x", "x"), list(range(1, total_length + 1)))
+            for total_length in range(1, end + 1)
+        ]
+        return samples
+
+    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 9)])
+    def test_generations(self, batch_size, end):
+        _collate_gen = lambda x: (-len(x[0]), x[0])  # noqa: E731
+
+        generation_samples = self.make_generate_sample(int(end))
+        gens = Collator(generation_samples, _collate_gen, grouping=True)
+        chunks = gens.get_batched(n=int(batch_size), batch_fn=None)
+        output = []
+        for chunks in chunks:
+            # check batching
+            group_one = end // 2
+            group_two = end - end // 2
+            assert (
+                len(chunks) <= batch_size
+                if batch_size != 0
+                else len(chunks) in [group_one, group_two]
+            )
+            # check if reorder-er is working correctly
+            assert all(
+                len(chunks[i][0]) <= len(chunks[i - 1][0])
+                for i in range(1, len(chunks))
+            )
+            # check if grouping correctly
+            assert all(x[1] == chunks[0][1] for x in chunks)
+            for x in chunks:
+                output.append(x)
+        reordered_output = gens.get_original(output)
+        # check get original
+        assert reordered_output == generation_samples
+
+    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 3)])
+    def test_loglikelihood(self, batch_size, end):
+        _collate_log = lambda x: (-len(x[1]), tuple(x[1]))  # noqa: E731
+        loglikelihood_samples = self.make_loglikelihood_sample(int(end))
+        loglikelihoods = Collator(loglikelihood_samples, _collate_log, grouping=False)
+        chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
+        output = []
+        for chunks in chunks:
+            # check batching
+            assert len(chunks) <= batch_size if batch_size != 0 else len(chunks) == end
+            # check reorder
+            assert all(
+                len(chunks[i][1]) <= len(chunks[i - 1][1])
+                for i in range(1, len(chunks))
+            )
+            for x in chunks:
+                output.append(x[1])
+        # check indices
+        reordered_output = loglikelihoods.get_original(output)
+        assert reordered_output == [x[1] for x in loglikelihood_samples]