Unverified Commit 9fb2ebab authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub

Consolidate batching (#1197)



* refactor dataloader

* cleanup + add docs

* change arg

* renamed Collator and added testing

* parametrized test for Collator

* appease pre-commit

* added edge case batch 0 (no batching)

* fix typos

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent b12bb1d4
lm_eval/models/huggingface.py

 import copy
 import os
-from collections import defaultdict
 from pathlib import Path
 from typing import List, Literal, Optional, Tuple, Union
@@ -21,7 +20,7 @@ from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
-from lm_eval.utils import stop_sequences_criteria
+from lm_eval.utils import Collator, stop_sequences_criteria

 eval_logger = utils.eval_logger
@@ -632,7 +631,7 @@ class HFLM(LM):
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Tuple[List[int], List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
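Reviewer note: the annotation change reflects that `tok_batch_encode` returns padded tensors rather than token-id lists. A minimal sketch of the same padding contract with a stock Hugging Face tokenizer; "gpt2" is only an illustrative checkpoint, not something this PR pins:

```python
# Sketch: batched tokenization with left padding, matching the new return type.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

old_side = tokenizer.padding_side
tokenizer.padding_side = "left"  # pad left so generation continues from real tokens
encoding = tokenizer(
    ["a short prompt", "a somewhat longer prompt to force padding"],
    return_tensors="pt",
    padding="longest",
)
tokenizer.padding_side = old_side

# Both are torch.Tensor, matching Tuple[torch.Tensor, torch.Tensor].
input_ids, attn_mask = encoding["input_ids"], encoding["attention_mask"]
print(input_ids.shape, attn_mask.shape)  # same shape: (2, longest_len)
```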
@@ -842,6 +841,7 @@
         res = []

         def _collate(x):
+            """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
@@ -852,26 +852,27 @@
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)

-        re_ord = utils.Reorderer(requests, _collate)
+        re_ord = Collator(requests, sort_fn=_collate)

-        n_reordered_requests = len(re_ord.get_reordered())
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        chunks = utils.chunks(
-            re_ord.get_reordered(),
-            n=self.batch_size
-            if self.batch_size != "auto"
-            else override_bs
-            if override_bs is not None
-            else 0,
-            fn=self._batch_scheduler
-            if self.batch_size == "auto"
-            and n_reordered_requests > 0
-            and not override_bs
-            else None,
-        )
+        n_reordered_requests = len(re_ord)
+
+        batch_size = (
+            self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
+            if self.batch_size == "auto"
+            and n_reordered_requests > 0
+            and not override_bs
+            else None
+        )
+        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
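Reviewer note: `get_batched` takes either a fixed `n` or a `batch_fn` invoked as `fn(index, iterable)` that returns the target batch size at that position. A hedged sketch of that contract, using the `Collator` added in `lm_eval/utils.py` below; the scheduler here is a hypothetical stand-in for `_batch_scheduler`, and the data is made up:

```python
from lm_eval.utils import Collator

# Hypothetical scheduler: requests are sorted longest-first, so later
# (shorter) positions can afford larger batches.
def toy_batch_scheduler(pos, _reordered):
    return 2 if pos < 4 else 4

data = ["x" * n for n in range(1, 9)]
c = Collator(data, sort_fn=lambda s: -len(s))  # sort descending by length
for batch in c.get_batched(n=0, batch_fn=toy_batch_scheduler):
    print([len(s) for s in batch])  # [8, 7], then [6, 5], then [4, 3, 2, 1]
```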
         pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
         for chunk in chunks:
             inps = []
@@ -1015,10 +1016,10 @@
         return re_ord.get_original(res)

     def generate_until(self, requests: List[Instance]) -> List[str]:
-        res = defaultdict(list)
-        re_ords = {}
+        res = []

         def _collate(x):
+            """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
@@ -1028,14 +1029,6 @@
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]

-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
-        for key, reqs in grouper.get_grouped().items():
-            # within each set of reqs for given kwargs, we reorder by token length, descending.
-            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -1044,98 +1037,102 @@
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
chunks = utils.chunks(
re_ord.get_reordered(),
n=self.batch_size
if self.batch_size != "auto"
else adaptive_batch_size
if adaptive_batch_size is not None
else 0,
fn=self._batch_scheduler
if self.batch_size == "auto" and not adaptive_batch_size
else None,
)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [kwargs]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
batch_size = (
self.batch_size
if self.batch_size != "auto"
else adaptive_batch_size
if adaptive_batch_size is not None
else 0
)
batch_fn = (
self._batch_scheduler
if self.batch_size == "auto" and not adaptive_batch_size
else None
)
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
stop=until,
**kwargs,
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [kwargs]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
cont_toks = cont_toks[context_enc.shape[1] :]
if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
s = self.tok_decode(cont_toks)
# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
stop=until,
**kwargs,
)
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0:
# ignore '' separator,
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
cont_toks = cont_toks[context_enc.shape[1] :]
res[key].append(s)
s = self.tok_decode(cont_toks)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0:
# ignore '' separator,
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
res.append(s)
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)
pbar.close()
return grouper.get_original(res)
return res
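Reviewer note: the post-hoc `until` trimming above is easy to sanity-check in isolation. A minimal sketch; the stop strings and generated text are made up:

```python
# Post-hoc stop-sequence trimming as done after generation: keep only the text
# before the first occurrence of any stop string.
until = ["\n\n", "Q:"]
s = "Paris is the capital.\n\nQ: next question"
for term in until:
    if len(term) > 0:  # skip '', e.g. a seq2seq EOT that decodes to the empty string
        s = s.split(term)[0]
print(s)  # -> "Paris is the capital."
```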
lm_eval/utils.py

@@ -13,7 +13,18 @@ import sys
import time
from functools import wraps
from itertools import islice
-from typing import Any, Callable, Iterator, List, Literal, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
import torch
import transformers
@@ -756,3 +767,157 @@ def retry_on_specific_exceptions(
        return wrapper

    return decorator
class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function,
    grouping elements based on a grouping function, and generating batches from
    the sorted and grouped data.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable,
        group_fn: Callable = lambda x: x[1],
        grouping: bool = False,
    ) -> None:
        self.grouping = grouping
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1])  # first index is the enumerated index
        self.reorder_indices: List = []
        self.size = len(arr)
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
        if self.grouping is True:
            self.group_by_index()

    def group_by_index(self) -> None:
        self.arr_with_indices = self.group(
            self.arr_with_indices, fn=self.group_fn, values=False
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.

        Yields:
        Iterator: An iterator over batches of reordered elements.
        """
        if self.grouping:
            for (
                key,
                values,
            ) in self.arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        Reordered elements one by one.
        """
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (List): The reordered array.

        Returns:
        List: The array with elements restored to their original order.
        """
        res = [None] * self.size
        cov = [False] * self.size

        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self.size

    def group(self, arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """
        Groups elements of an iterable based on a provided function.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - values (bool): If True, returns the values of the group. Defaults to False.

        Returns:
        Iterable: An iterable of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                # for (context, gen_kwargs) pairs, build a hashable key from the
                # kwargs dict (list values such as `until` become tuples)
                hashable_dict = tuple(
                    (key, tuple(value) if isinstance(value, list) else value)
                    for key, value in sorted(ob[1][1].items())
                )
                res[hashable_dict].append(ob)
            except TypeError:
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()

    def get_chunks(self, iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching.

        Parameters:
        - iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in get_chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        for i, x in enumerate(iter):
            arr.append(x)
            if len(arr) == (fn(i, iter) if fn else n):
                yield arr
                arr = []

        # with n=0 and no fn, the condition above never fires, so everything is
        # yielded here as one chunk (the "no batching" edge case)
        if arr:
            yield arr
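Reviewer note: a round-trip usage sketch of the new class (sort, batch, restore), assuming only the definitions above; the sample data and sort keys are made up:

```python
from lm_eval.utils import Collator

requests = ["bb", "dddd", "a", "ccc"]
c = Collator(requests, sort_fn=lambda s: -len(s))  # longest first

results = []
for batch in c.get_batched(n=2):
    # batches come out sorted: ["dddd", "ccc"], then ["bb", "a"]
    results.extend(s.upper() for s in batch)  # stand-in for model work

# get_original maps results back to input order
assert c.get_original(results) == ["BB", "DDDD", "A", "CCC"]

# With grouping=True, elements are first grouped by group_fn (default: x[1],
# i.e. the gen_kwargs in a (context, gen_kwargs) tuple), so incompatible
# sampling settings never share a batch.
gen_reqs = [
    ("p1", {"temperature": 0}),
    ("p2", {"temperature": 0.8}),
    ("p3", {"temperature": 0}),
]
g = Collator(gen_reqs, sort_fn=lambda x: -len(x[0]), grouping=True)
for batch in g.get_batched(n=2):
    assert all(kw == batch[0][1] for _, kw in batch)  # homogeneous gen_kwargs
```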
tests/test_utils.py

-from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
+import pytest
+
+from lm_eval.utils import Collator, get_rolling_token_windows, make_disjoint_window
# noinspection DuplicatedCode
@@ -220,3 +222,76 @@ def test_make_disjoint_window():
    )
    assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
    assert make_disjoint_window(([1, 2, 3, 4, 5], [6])) == ([1, 2, 3, 4, 5], [6])
class TestCollator:
    def make_generate_sample(self, end=10):
        strings = ["x" * i for i in range(1, end + 1)]
        gen_kwargs1, gen_kwargs2 = (
            {"temperature": 0},
            {"temperature": 0, "until": ["nn", "\n\n"]},
        )
        args = [
            (string, gen_kwargs1 if i < len(strings) // 2 else gen_kwargs2)
            for i, string in enumerate(strings)
        ]

        return args

    def make_loglikelihood_sample(self, end=11):
        samples = [
            (("x", "x"), list(range(1, total_length + 1)))
            for total_length in range(1, end + 1)
        ]

        return samples

    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 9)])
    def test_generations(self, batch_size, end):
        _collate_gen = lambda x: (-len(x[0]), x[0])  # noqa: E731

        generation_samples = self.make_generate_sample(int(end))
        gens = Collator(generation_samples, _collate_gen, grouping=True)
        chunks = gens.get_batched(n=int(batch_size), batch_fn=None)
        output = []
        for batch in chunks:
            # check batching
            group_one = end // 2
            group_two = end - end // 2
            assert (
                len(batch) <= batch_size
                if batch_size != 0
                else len(batch) in [group_one, group_two]
            )
            # check that the reorderer sorts descending by length
            assert all(
                len(batch[i][0]) <= len(batch[i - 1][0]) for i in range(1, len(batch))
            )
            # check grouping: every element in a batch shares the same gen_kwargs
            assert all(x[1] == batch[0][1] for x in batch)
            for x in batch:
                output.append(x)
        reordered_output = gens.get_original(output)
        # check that get_original restores the input order
        assert reordered_output == generation_samples

    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 3)])
    def test_loglikelihood(self, batch_size, end):
        _collate_log = lambda x: (-len(x[1]), tuple(x[1]))  # noqa: E731

        loglikelihood_samples = self.make_loglikelihood_sample(int(end))
        loglikelihoods = Collator(loglikelihood_samples, _collate_log, grouping=False)
        chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
        output = []
        for batch in chunks:
            # check batching
            assert len(batch) <= batch_size if batch_size != 0 else len(batch) == end
            # check reorder (descending length)
            assert all(
                len(batch[i][1]) <= len(batch[i - 1][1]) for i in range(1, len(batch))
            )
            for x in batch:
                output.append(x[1])
        # check indices
        reordered_output = loglikelihoods.get_original(output)
        assert reordered_output == [x[1] for x in loglikelihood_samples]