Unverified Commit 45941c67 authored by Baber Abbasi, committed by GitHub

Group reqs by context (#1425)



* add key lookup for same contexts

* nit

* appease pre-commit

* nit

* use `expand` (in-place view) rather than `repeat`

* try mixed grouping

* add docs.

* nit

* nit

* nits

* fix tests

* Move greedy_tokens calculation out of cache loop

* nit

* nits

* add test

* nits

* fix name conflict

* fix name conflict

* chunk tensor

* move Collator

* nits/docstring

* fixup

* fixup

* group contexts only for decoders

* pre-commit

* fix `generate_until` test

* fix `generate_until` test

* Update lm_eval/models/huggingface.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* add docs

* nit

* add docs

* add docs

* add 'logits_cache' arg

* bugfix

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 8680e938
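
Before the diff, here is a minimal, hypothetical sketch of the `group_by="contexts"` path this PR adds to `Collator` (it mirrors the `test_context_grouping` test added below). The import path is taken from the diffed file; the toy request tuples, token ids, and the random logits tensor are illustrative assumptions, not part of the change.

```python
# Hedged sketch of the new Collator "contexts" grouping (mirrors test_context_grouping below).
# Assumptions: Collator lives in lm_eval.utils as in this diff; the request tuples, token ids,
# and the random logits tensor are made up for illustration.
import torch

from lm_eval.utils import Collator

# loglikelihood-style requests: ((context_str, continuation_str), context_tokens, continuation_tokens)
# All three share a context and have one-token continuations, so they collapse into a single
# forward pass when grouped by context + continuation[:-1].
requests = [
    (("The capital of France is", " Paris"), [10, 11, 12], [7]),
    (("The capital of France is", " Rome"), [10, 11, 12], [8]),
    (("The capital of France is", " Berlin"), [10, 11, 12], [9]),
]

re_ord = Collator(
    requests,
    sort_fn=lambda req: (-len(req[1] + req[2]), tuple(req[1] + req[2])),
    group_fn=lambda req: req[-2] + req[-1][:-1],  # lookup key: context + continuation[:-1]
    group_by="contexts",
)

results = []
for batch in re_ord.get_batched(n=8):
    # only one representative request per context group reaches the "model"
    for req_str, ctx_toks, cont_toks in batch:
        logits = torch.randn(1, len(ctx_toks) + len(cont_toks), 32)  # fake [1, seq, vocab] output
        # get_cache fans the shared logits back out to every request in the group
        for cache_key, cont_tok, logit in re_ord.get_cache(
            req_str=req_str,
            cxt_toks=ctx_toks,
            cont_toks=cont_toks,
            logits=logits,
        ):
            results.append((cache_key, cont_tok, tuple(logit.shape)))

results = re_ord.get_original(results)  # back to the original request order
print(results)
```
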
@@ -4,11 +4,11 @@ import random
 from collections.abc import Iterable
 from typing import List
+import evaluate as hf_evaluate
 import numpy as np
 import sacrebleu
 import sklearn.metrics
-import evaluate
 from lm_eval.api.registry import register_aggregation, register_metric
@@ -146,7 +146,7 @@ def acc_mutual_info_fn(items):  # This is a passthrough function
     return items
-exact_match = evaluate.load("exact_match")
+exact_match = hf_evaluate.load("exact_match")
 @register_metric(
...
 import logging
 from typing import Callable, Dict
-import evaluate
+import evaluate as hf_evaluate
 from lm_eval.api.model import LM
@@ -128,7 +129,7 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
             )
     try:
-        metric_object = evaluate.load(name)
+        metric_object = hf_evaluate.load(name)
         return metric_object.compute
     except Exception:
         eval_logger.error(
...
@@ -78,9 +78,8 @@ class HFLM(LM):
     def __init__(
         self,
         pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
-        backend: Optional[
-            Literal["default", "causal", "seq2seq"]
-        ] = "default",  # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
+        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
+        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
         revision: Optional[str] = "main",
         subfolder: Optional[str] = None,
         tokenizer: Optional[
@@ -91,6 +90,7 @@ class HFLM(LM):
             ]
         ] = None,
         truncation: Optional[bool] = False,
+        logits_cache: bool = True,
        max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
@@ -239,7 +239,7 @@ class HFLM(LM):
         )
         self.truncation = truncation
+        self.logits_cache = logits_cache
         self.vocab_size = self.tokenizer.vocab_size
         # select (or create) a pad token to use
         if self.tokenizer.pad_token:
@@ -760,7 +760,9 @@ class HFLM(LM):
             **generation_kwargs,
         )
-    def _select_cont_toks(self, logits, contlen=None, inplen=None):
+    def _select_cont_toks(
+        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
+    ) -> torch.Tensor:
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             assert (
                 contlen and inplen
@@ -809,7 +811,7 @@ class HFLM(LM):
             new_reqs.append(((context, continuation), context_enc, continuation_enc))
-        return self._loglikelihood_tokens(new_reqs)
+        return self._loglikelihood_tokens(requests=new_reqs)
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
@@ -851,7 +853,7 @@ class HFLM(LM):
                 rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
             string_nll = self._loglikelihood_tokens(
-                rolling_token_windows,
+                requests=rolling_token_windows,
                 disable_tqdm=True,
                 override_bs=adaptive_batch_size,
             )
@@ -893,7 +895,7 @@ class HFLM(LM):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
-        def _collate(x):
+        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
@@ -902,10 +904,26 @@ class HFLM(LM):
             # automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
-            toks = x[1] + x[2]
+            toks = req[1] + req[2]
             return -len(toks), tuple(toks)
-        re_ord = Collator(requests, sort_fn=_collate)
+        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+            """Defines the key to group and lookup one-token continuations"""
+            # Use with group_by="contexts" (optional)
+            # allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
+            # speeds up some multiple-choice tasks proportionally to the number of choices.
+            # groups requests by context+continuation[:-1] and infer on one request/group.
+            return req[-2] + req[-1][:-1]
+        re_ord = Collator(
+            requests,
+            sort_fn=_collate,
+            group_by="contexts"
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+            and self.logits_cache
+            else None,
+            group_fn=_lookup_one_token_cont,
+        )
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
@@ -1026,7 +1044,7 @@ class HFLM(LM):
                 self._model_call(batched_inps, **call_kwargs), dim=-1
             )  # [batch, padding_length (inp or cont), vocab]
-            for (cache_key, _, _), logits, inplen, cont_toks in zip(
+            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                 chunk, multi_logits, inplens, cont_toks_list
             ):
                 # Slice to original seq length
@@ -1045,24 +1063,36 @@ class HFLM(LM):
                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
-                cont_toks = torch.tensor(
-                    cont_toks, dtype=torch.long, device=self.device
-                ).unsqueeze(0)  # [1, seq]
-                max_equal = (greedy_tokens == cont_toks).all()
-                # Obtain log-probs at the corresponding continuation token indices
-                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-                    -1
-                )  # [1, seq]
-                # Answer: (log prob, is-exact-match)
-                answer = (float(logits.sum()), bool(max_equal))
-                res.append(answer)
-                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
-                pbar.update(1)
+                # check for one-token continuation cache hits.
+                # noop in case group_by != "contexts" or no cache hit and returns the
+                # original args. Otherwise, expands the logits batch dimension and yields each
+                # batch along with matching continuation tokens and prompt strings.
+                # logits -> [1, seq, vocab]
+                for request_str, cont_toks, logits in re_ord.get_cache(
+                    req_str=request_str,
+                    cxt_toks=ctx_tokens,
+                    cont_toks=cont_toks,
+                    logits=logits,
+                ):
+                    cont_toks = torch.tensor(
+                        cont_toks, dtype=torch.long, device=self.device
+                    ).unsqueeze(0)  # [1, seq]
+                    max_equal = (greedy_tokens == cont_toks).all()
+                    # Obtain log-probs at the corresponding continuation token indices
+                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
+                        -1
+                    )  # [1, seq]
+                    # Answer: (log prob, is-exact-match)
+                    answer = (float(logits.sum()), bool(max_equal))
+                    res.append(answer)
+                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
+                    pbar.update(1)
         pbar.close()
@@ -1071,7 +1101,7 @@ class HFLM(LM):
     def generate_until(self, requests: List[Instance]) -> List[str]:
         res = []
-        def _collate(x):
+        def _collate(req: Tuple[str, dict]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
@@ -1079,8 +1109,8 @@ class HFLM(LM):
             # padded context length. this is useful to simplify the batching logic and more importantly to make
             # automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(x[0])
-            return -len(toks), x[0]
+            toks = self.tok_encode(req[0])
+            return -len(toks), req[0]
         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
         adaptive_batch_size = None
@@ -1107,7 +1137,13 @@ class HFLM(LM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
+        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
+        re_ords = Collator(
+            [reg.args for reg in requests],
+            sort_fn=_collate,
+            group_by="gen_kwargs",
+            group_fn=lambda x: x[1],
+        )
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
...
@@ -6,6 +6,7 @@ from functools import wraps
 from typing import (
     Any,
     Callable,
+    Dict,
     Iterable,
     Iterator,
     List,
@@ -357,65 +358,164 @@ class Collator:
     A class for reordering and batching elements of an array.
     This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
+    Objects of this class have the group_by attribute which determines the method for grouping
+    the data while batching it. Three options include "gen_kwargs", "contexts", or None:
+        If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
+        If group_by == "contexts" then requests will be grouped by context + cont[:-1]
+        If None then requests will just be reordered by length descending.
     """
     def __init__(
         self,
         arr: List,
-        sort_fn: Callable,
+        sort_fn: Callable = lambda x: x,
         group_fn: Callable = lambda x: x[1],
-        grouping: bool = False,
+        group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
     ) -> None:
-        self.grouping = grouping
-        self.fn = sort_fn
-        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
-        self.reorder_indices: List = []
-        self.size = len(arr)
-        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
-        if self.grouping is True:
-            self.group_by_index()
-    def group_by_index(self) -> None:
-        self.arr_with_indices = self.group(
-            self.arr_with_indices, fn=self.group_fn, values=False
+        self._group_by = group_by
+        # 0 indices are enumerated indices. Apply functions to original arr.
+        self._sort_fn = lambda x: sort_fn(x[1])
+        self._group_fn = lambda x: group_fn(x[1])
+        self._reorder_indices: List = []
+        self._size = len(arr)
+        self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
+            enumerate(arr)
+        )  # [indices, (arr)]
+        if self._group_by == "contexts":
+            self._group_by_context()
+        elif self._group_by == "gen_kwargs":
+            self._group_by_index()
+    def _group_by_index(self) -> None:
+        """Group the elements of a list based on their indices."""
+        self._arr_with_indices = self.group(
+            self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
+        )
+    def _group_by_context(self) -> None:
+        """Group the array with indices by context."""
+        self._arr_with_indices = self.group(
+            self._arr_with_indices, fn=self._group_fn, group_by="contexts"
         )
     def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
         """
-        Generates and yields batches from the reordered array.
+        Generates and yields batches from the reordered array. The method of grouping and batching
+        depends on the parameter `group_by`.
+        If `group_by` is set to "gen_kwargs", it will batch the
+        re-ordered values with same gen_kwargs for each batch.
+        If `group_by` is "contexts", it caches the requests by context before batching.
+        If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array
        Parameters:
         - n (int): The size of each batch. Defaults to 1.
-        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.
+        - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
+          each batch. Optional, defaults to None.
+        Returns:
+        Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
+                  attribute.
         Yields:
-        Iterator: An iterator over batches of reordered elements.
+        List of batched elements according to the `group_by` attribute.
         """
-        if self.grouping:
+        if self._group_by == "gen_kwargs":
             for (
                 key,
                 values,
-            ) in self.arr_with_indices.items():  # type: ignore
+            ) in self._arr_with_indices.items():  # type: ignore
                 values = self._reorder(values)
                 batch = self.get_chunks(values, n=n, fn=batch_fn)
                 yield from batch
+        elif self._group_by == "contexts":
+            # Get one sample from each key
+            values = self._reorder(
+                [value[0] for value in self._arr_with_indices.values()]
+            )
+            batch = self.get_chunks(values, n=n, fn=batch_fn)
+            yield from batch
         else:
-            values = self._reorder(self.arr_with_indices)  # type: ignore
+            values = self._reorder(self._arr_with_indices)  # type: ignore
             batch = self.get_chunks(values, n=n, fn=batch_fn)
             yield from batch
-    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
+    def get_cache(
+        self,
+        req_str: Tuple[str, str] = None,
+        cxt_toks: List[int] = None,
+        cont_toks: List[int] = None,
+        logits: torch.Tensor = None,
+    ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
+        """
+        Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.
+        The behavior of this function varies depending on how the `group_by` attribute is set:
+        - When `group_by` is "contexts":
+            The function identifies single-token continuations by checking for keys that equate to
+            [context+continuation][-1] and logs the indices for re-ordering.
+            In this mode, this function can work in two scenarios:
+            1. Cache Hit - Single Match:
+                If a single matching context-continuation pair is found in the cache,
+                the function yields the original arguments.
+            2. Cache Hit - Multiple Matches:
+                If multiple matching context-continuation pairs are found in the cache,
+                the function expands the logits batch dimension to match the number of cache hits.
+                It updates the original requests and continuation tokens.
+        - When `group_by` is not set to "contexts":
+            This method yields the original arguments, logits and continuation tokens,
+            without checking for one-token continuations.
+        Parameters:
+        - req_str (tuple[str, str]): Original strings used for CachingLM.
+        - cxt_toks (list[int]): Full context tokens used for lookup.
+        - cont_toks (list[int]): Continuation tokens for which logits were generated.
+        - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.
+        Yields:
+        - Iterator:
+            - req_str (tuple[str, str]): strings used for CachingLM.
+            - cont_toks (list[int]): continuation tokens.
+            - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
+        """
+        if self._group_by == "contexts":
+            cache_hit: List[
+                Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
+            ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
+            if (cache_size := len(cache_hit)) == 1:
+                self._reorder_indices.extend(x[0] for x in cache_hit)
+                yield req_str, cont_toks, logits
+            else:
+                # If we have matching requests then expand the batch dimension (no-op) and
+                # yield each along with its corresponding args.
+                multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
+                indices, req_str, cont_toks = zip(
+                    *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
+                )
+                self._reorder_indices.extend(indices)
+                for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
+                    yield c_key, cont_tok, logit
+        else:
+            yield req_str, cont_toks, logits
+    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
         """
         Reorders the elements in the array based on the sorting function.
         Parameters:
-        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.
+        - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered.
         Yields:
-        List: Yields reordered elements one by one.
+        Iterator
         """
-        arr = sorted(arr, key=lambda x: self.fn(x[1]))
-        self.reorder_indices.extend([x[0] for x in arr])
+        arr = sorted(arr, key=self._sort_fn)
+        if not self._group_by == "contexts":
+            # If grouped by contexts then indices will be set in get_cache()
+            self._reorder_indices.extend([x[0] for x in arr])
         yield from [x[1] for x in arr]
     def get_original(self, newarr: List) -> List:
@@ -423,15 +523,15 @@ class Collator:
         Restores the original order of elements from the reordered list.
         Parameters:
-        - newarr (List): The reordered array.
+        - newarr (list): The reordered array.
         Returns:
-        List: The array with elements restored to their original order.
+        list: The array with elements restored to their original order.
         """
-        res = [None] * self.size
-        cov = [False] * self.size
-        for ind, v in zip(self.reorder_indices, newarr):
+        res = [None] * self._size
+        cov = [False] * self._size
+        for ind, v in zip(self._reorder_indices, newarr):
             res[ind] = v
             cov[ind] = True
@@ -440,39 +540,50 @@ class Collator:
         return res
     def __len__(self):
-        return self.size
+        return self._size
     @staticmethod
-    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
+    def group(
+        arr: Iterable,
+        fn: Callable,
+        group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
+    ) -> dict:
         """
         Groups elements of an iterable based on a provided function.
+        The `group_by` parameter determines the method of grouping.
+        If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
+        If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.
         Parameters:
         - arr (Iterable): The iterable to be grouped.
         - fn (Callable): The function to determine the grouping.
         - values (bool): If True, returns the values of the group. Defaults to False.
         Returns:
-        Iterable: An iterable of grouped elements.
+        Iterator: An iterable of grouped elements.
         """
         res = collections.defaultdict(list)
         for ob in arr:
-            try:
-                hashable_dict = tuple(
-                    (
-                        key,
-                        tuple(value)
-                        if isinstance(value, collections.abc.Iterable)
-                        else value,
-                    )
-                    for key, value in sorted(fn(ob).items())
-                )
-                res[hashable_dict].append(ob)
-            except TypeError:
-                res[fn(ob)].append(ob)
-        if not values:
-            return res
-        return res.values()
+            # where ob == [context + cont]
+            if group_by == "contexts":
+                res[tuple(fn(ob))].append(ob)
+            else:
+                try:
+                    hashable_dict = tuple(
+                        (
+                            key,
+                            tuple(value)
+                            if isinstance(value, collections.abc.Iterable)
+                            else value,
+                        )
+                        for key, value in sorted(fn(ob).items())
+                    )
+                    res[hashable_dict].append(ob)
+                except (TypeError, AttributeError):
+                    res[tuple(fn(ob))].append(ob)
+        return res
     @staticmethod
     def get_chunks(_iter, n: int = 0, fn=None):
...
@@ -276,7 +276,7 @@ class VLLM(LM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(requests, _collate_gen, grouping=True)
+        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
...
@@ -16,4 +16,3 @@ filter_list:
       - function: "regex"
         regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=the answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
       - function: "take_first"
-
@@ -15,4 +15,3 @@ filter_list:
       - function: "regex"
         regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=the answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
       - function: "take_first"
-
@@ -74,7 +74,7 @@ class Test_HFLM:
     generate_until_RES = [
         " The average of $2.50 each is $",
         " A robe takes 2 bolts of blue fiber and half",
-        " $50,000 in repairs.",
+        " $50,000 in repairs.\n\nQuestion",
         " He runs 1 sprint 3 times a week.",
         " They feed each of her chickens three cups of mixed",
         " The price of the glasses is $5, but",
...
@@ -2,6 +2,7 @@ import itertools
 import numpy as np
 import pytest
+import torch
 from lm_eval.api.metrics import (
     aggregate_subtask_metrics,
@@ -258,12 +259,20 @@ class TestCollator:
         ]
         return samples
+    def make_loglikelihood_sample_group(self, end=11):
+        a = [(("x", "x"), [1, 2, 3, 4, 5, 6, 7, 8], [x]) for x in range(9)]
+        b = [
+            (("x", "x"), [1, 2, 3, 4, 5, 6, 7, 8], [x, y, z])
+            for x, y, z in zip(range(9), range(9, 18), range(18, 27))
+        ]
+        return a + b
     @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 9)])
     def test_generations(self, batch_size, end):
         _collate_gen = lambda x: (-len(x[0]), x[0])  # noqa: E731
         generation_samples = self.make_generate_sample(int(end))
-        gens = Collator(generation_samples, _collate_gen, grouping=True)
+        gens = Collator(generation_samples, _collate_gen, group_by="gen_kwargs")
         chunks = gens.get_batched(n=int(batch_size), batch_fn=None)
         output = []
         for chunks in chunks:
@@ -292,7 +301,10 @@ class TestCollator:
     def test_loglikelihood(self, batch_size, end):
         _collate_log = lambda x: (-len(x[1]), tuple(x[1]))  # noqa: E731
         loglikelihood_samples = self.make_loglikelihood_sample(int(end))
-        loglikelihoods = Collator(loglikelihood_samples, _collate_log, grouping=False)
+        loglikelihoods = Collator(
+            loglikelihood_samples,
+            _collate_log,
+        )
         chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
         output = []
         for chunks in chunks:
@@ -309,6 +321,48 @@ class TestCollator:
         reordered_output = loglikelihoods.get_original(output)
         assert reordered_output == [x[1] for x in loglikelihood_samples]
+    @pytest.mark.parametrize("batch_size", [17, 8, 12, 0])
+    def test_context_grouping(self, batch_size):
+        def _collate(x):
+            toks = x[1] + x[2]
+            return -len(toks), tuple(toks)
+        _collate_log = _collate  # noqa: E731
+        loglikelihood_samples = self.make_loglikelihood_sample_group()
+        loglikelihoods = Collator(
+            loglikelihood_samples,
+            _collate_log,
+            group_fn=lambda a: a[-2] + a[-1][:-1],
+            group_by="contexts",
+        )
+        chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
+        output = []
+        outputs_ = []
+        for chunks in chunks:
+            # check batching
+            if batch_size != 0:
+                assert len(chunks) <= batch_size
+            # check reorder
+            assert all(
+                len(chunks[i][1]) <= len(chunks[i - 1][1])
+                for i in range(1, len(chunks))
+            )
+            for x in chunks:
+                for request_str, cont_toks, logits in loglikelihoods.get_cache(
+                    req_str="".join(x[0]),
+                    cxt_toks=x[1],
+                    cont_toks=x[2],
+                    logits=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
+                    .unsqueeze(0)
+                    .unsqueeze(0),
+                ):
+                    output.append(x[1])
+                    outputs_.append(cont_toks)
+        assert len(output) == len(outputs_)
+        # check indices
+        reordered_output = loglikelihoods.get_original(output)
+        assert reordered_output == [x[1] for x in loglikelihood_samples]
 def test_aggregate_mean():
     # test weight_by_size is respected
...
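
For completeness, the new one-token-continuation caching is on by default for decoder-only HF models and can be switched off through the `logits_cache` constructor argument introduced above. The snippet below is a hedged sketch: the argument name and import path come from the diff, while the model name and CPU device are illustrative choices.

```python
# Hedged sketch: opting out of the logits cache added in this PR.
# The logits_cache argument comes from the HFLM.__init__ diff above; "gpt2" and device="cpu"
# are illustrative choices, not part of the change.
from lm_eval.models.huggingface import HFLM

lm = HFLM(pretrained="gpt2", device="cpu", logits_cache=False)
# With logits_cache=False (or for seq2seq backends), loglikelihood requests are no longer
# grouped by context before the forward pass.
```
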