Unverified commit 9822b06e authored by Lintang Sutawika, committed by GitHub

Merge branch 'main' into weight_by_size

parents 51f27158 b177c82c
import copy
import os
+from datetime import timedelta
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
-from accelerate import Accelerator, DistributedType, find_executable_batch_size
+from accelerate import (
+    Accelerator,
+    DistributedType,
+    InitProcessGroupKwargs,
+    find_executable_batch_size,
+)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION

@@ -18,9 +24,15 @@ from transformers.models.auto.modeling_auto import (
from lm_eval import utils
from lm_eval.api.instance import Instance
-from lm_eval.api.model import LM
+from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
-from lm_eval.utils import Collator, stop_sequences_criteria
+from lm_eval.models.utils import (
+    Collator,
+    clear_torch_cache,
+    get_dtype,
+    pad_and_concat,
+    stop_sequences_criteria,
+)

eval_logger = utils.eval_logger

@@ -52,7 +64,7 @@ def _get_accelerate_args(

@register_model("hf-auto", "hf", "huggingface")
-class HFLM(LM):
+class HFLM(TemplateLM):
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

@@ -66,9 +78,8 @@ class HFLM(LM):
    def __init__(
        self,
        pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
-        backend: Optional[
-            Literal["default", "causal", "seq2seq"]
-        ] = "default",  # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
+        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
+        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[

@@ -79,6 +90,7 @@ class HFLM(LM):
            ]
        ] = None,
        truncation: Optional[bool] = False,
+        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",

@@ -86,6 +98,7 @@ class HFLM(LM):
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
+        add_bos_token: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,

@@ -108,8 +121,8 @@ class HFLM(LM):
            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
+            gpus = 0

            if tokenizer:
                assert isinstance(

@@ -132,7 +145,8 @@ class HFLM(LM):
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
-            accelerator = Accelerator()
+            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

@@ -226,7 +240,7 @@ class HFLM(LM):
            )

        self.truncation = truncation
+        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        if self.tokenizer.pad_token:

@@ -236,7 +250,7 @@ class HFLM(LM):
        elif self.tokenizer.eos_token:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        else:
-            if self.config.model_type == "qwen":
+            if getattr(self.config, "model_type", None) == "qwen":
                # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
                self.tokenizer.pad_token = "<|endoftext|>"
            elif (

@@ -252,6 +266,14 @@ class HFLM(LM):
            else:
                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

+        # TODO: override this for Gemma
+        self.add_bos_token = add_bos_token
+        if getattr(self.config, "model_type", None) == "gemma":
+            self.add_bos_token = True
+            eval_logger.info(
+                f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it."
+            )

        self._max_length = max_length

        self.batch_schedule = 1

@@ -372,7 +394,7 @@ class HFLM(LM):
    def _get_backend(
        self,
-        config: transformers.AutoConfig,
+        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:

@@ -496,13 +518,13 @@ class HFLM(LM):
            if transformers.__version__ >= "4.30.0":
                if model_kwargs.get("load_in_4bit", None):
                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
-                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
+                        model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
                            model_kwargs["bnb_4bit_compute_dtype"]
                        )
            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision,
-                torch_dtype=utils.get_dtype(dtype),
+                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

@@ -617,7 +639,13 @@ class HFLM(LM):
                return batch_size

-        batch_size = forward_batch()
+        try:
+            batch_size = forward_batch()
+        except RuntimeError as e:
+            if "No executable batch size found" in str(e):
+                batch_size = 1
+            else:
+                raise

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes

@@ -626,10 +654,10 @@ class HFLM(LM):
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
-            utils.clear_torch_cache()
+            clear_torch_cache()
            return batch_size

-        utils.clear_torch_cache()
+        clear_torch_cache()
        return batch_size

    def tok_encode(

@@ -638,8 +666,9 @@ class HFLM(LM):
        """ """
        if add_special_tokens is None:
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                add_special_tokens = False
+                add_special_tokens = False or self.add_bos_token
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                # TODO: investigate best practices for enc-dec models + special tokens
                add_special_tokens = True

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

@@ -662,7 +691,7 @@ class HFLM(LM):
        self.tokenizer.padding_side = padding_side

        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            add_special_tokens = False
+            add_special_tokens = False or self.add_bos_token
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True

@@ -721,6 +750,11 @@ class HFLM(LM):
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)
+        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
+        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+            generation_kwargs["do_sample"] = do_sample = False
        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")
        # build stopping criteria
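
The added temperature handling above boils down to a small normalization of the generation kwargs. A minimal standalone sketch of that logic (hypothetical helper name, for illustration only):

def normalize_gen_kwargs(kwargs: dict) -> dict:
    # mirrors the logic added to _model_generate above, pulled out for clarity
    kwargs["temperature"] = kwargs.get("temperature", 0.0)
    do_sample = kwargs.get("do_sample", None)
    if kwargs.get("temperature") == 0.0 and do_sample is None:
        # temperature 0.0 is interpreted as greedy decoding
        kwargs["do_sample"] = do_sample = False
    if do_sample is False and kwargs.get("temperature") == 0.0:
        # drop temperature so HF generate() does not warn about an unused sampling arg
        kwargs.pop("temperature")
    return kwargs

assert normalize_gen_kwargs({}) == {"do_sample": False}
assert normalize_gen_kwargs({"temperature": 0.8, "do_sample": True}) == {"temperature": 0.8, "do_sample": True}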

@@ -736,7 +770,9 @@ class HFLM(LM):
            **generation_kwargs,
        )

-    def _select_cont_toks(self, logits, contlen=None, inplen=None):
+    def _select_cont_toks(
+        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
+    ) -> torch.Tensor:
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen

@@ -754,39 +790,6 @@ class HFLM(LM):
        return logits

-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-        context_enc = self.tok_encode(context, add_special_tokens=False)
-        # whole_enc = self.tok_encode(context + continuation)
-        # context_enc = self.tok_encode(context, add_special_tokens=False)
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        loglikelihoods = []

@@ -827,7 +830,7 @@ class HFLM(LM):
                rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
-                rolling_token_windows,
+                requests=rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

@@ -869,7 +872,7 @@ class HFLM(LM):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

-        def _collate(x):
+        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning

@@ -878,10 +881,26 @@ class HFLM(LM):
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-            toks = x[1] + x[2]
+            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

-        re_ord = Collator(requests, sort_fn=_collate)
+        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+            """Defines the key to group and lookup one-token continuations"""
+            # Use with group_by="contexts" (optional)"
+            # allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
+            # speeds up some multiple-choice tasks proportionally to the number of choices.
+            # groups requests by context+continuation[:-1] and infer on one request/group.
+            return req[-2] + req[-1][:-1]
+
+        re_ord = Collator(
+            requests,
+            sort_fn=_collate,
+            group_by="contexts"
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+            and self.logits_cache
+            else None,
+            group_fn=_lookup_one_token_cont,
+        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
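
The `_lookup_one_token_cont` key above (context plus all but the last continuation token) is what lets the Collator reuse a single forward pass across multiple-choice options. A toy sketch of the grouping idea, using a plain dict rather than the actual Collator class (the request tuples here are made up for illustration):

from collections import defaultdict

def group_one_token_continuations(requests):
    # requests: list of (key, context_tokens, continuation_tokens)
    groups = defaultdict(list)
    for key, ctx, cont in requests:
        # context + continuation[:-1]; for one-token continuations this is just the context
        groups[tuple(ctx + cont[:-1])].append((key, ctx, cont))
    return groups  # one model call per group instead of one per request

reqs = [
    ("q-choiceA", [12, 7, 3], [101]),
    ("q-choiceB", [12, 7, 3], [205]),
    ("q-choiceC", [12, 7, 3], [311]),
]
assert len(group_one_token_continuations(reqs)) == 1  # three choices, one forward pass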

@@ -902,7 +921,11 @@ class HFLM(LM):
            )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
-        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running loglikelihood requests",
+        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []

@@ -979,18 +1002,18 @@ class HFLM(LM):
            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                batched_inps = utils.pad_and_concat(
+                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
-                batched_inps = utils.pad_and_concat(
+                batched_inps = pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
-                batched_conts = utils.pad_and_concat(
+                batched_conts = pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
-                batched_encoder_mask = utils.pad_and_concat(
+                batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {

@@ -1002,7 +1025,7 @@ class HFLM(LM):
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

-            for (cache_key, _, _), logits, inplen, cont_toks in zip(
+            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length

@@ -1021,24 +1044,36 @@ class HFLM(LM):
                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
-                cont_toks = torch.tensor(
-                    cont_toks, dtype=torch.long, device=self.device
-                ).unsqueeze(0)  # [1, seq]
-                max_equal = (greedy_tokens == cont_toks).all()
-                # Obtain log-probs at the corresponding continuation token indices
-                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-                    -1
-                )  # [1, seq]
-                # Answer: (log prob, is-exact-match)
-                answer = (float(logits.sum()), bool(max_equal))
-                res.append(answer)
-                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
-                pbar.update(1)
+                # check for one-token continuation cache hits.
+                # noop in case group_by != "contexts" or no cache hit and returns the
+                # original args. Otherwise, expands the logits batch dimension and yields each
+                # batch along with matching continuation tokens and prompt strings.
+                # logits -> [1, seq, vocab]
+                for request_str, cont_toks, logits in re_ord.get_cache(
+                    req_str=request_str,
+                    cxt_toks=ctx_tokens,
+                    cont_toks=cont_toks,
+                    logits=logits,
+                ):
+                    cont_toks = torch.tensor(
+                        cont_toks, dtype=torch.long, device=self.device
+                    ).unsqueeze(0)  # [1, seq]
+                    max_equal = (greedy_tokens == cont_toks).all()
+                    # Obtain log-probs at the corresponding continuation token indices
+                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
+                        -1
+                    )  # [1, seq]
+                    # Answer: (log prob, is-exact-match)
+                    answer = (float(logits.sum()), bool(max_equal))
+                    res.append(answer)
+                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
+                    pbar.update(1)

        pbar.close()

@@ -1047,7 +1082,7 @@ class HFLM(LM):
    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

-        def _collate(x):
+        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning

@@ -1055,10 +1090,15 @@ class HFLM(LM):
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(x[0])
-            return -len(toks), x[0]
+            toks = self.tok_encode(req[0])
+            return -len(toks), req[0]

-        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(self.rank != 0),
+            desc="Running generate_until requests",
+        )
+        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")

@@ -1082,7 +1122,13 @@ class HFLM(LM):
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
-        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
+        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
+        re_ords = Collator(
+            [reg.args for reg in requests],
+            sort_fn=_collate,
+            group_by="gen_kwargs",
+            group_fn=lambda x: x[1],
+        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
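
Grouping by `gen_kwargs` keeps incompatible sampling settings out of the same batch, e.g. greedy and temperature-0.8 requests never run together. A rough standalone sketch of that grouping (plain Python, not the Collator implementation):

from collections import defaultdict

def group_by_gen_kwargs(requests):
    # requests: list of (context, gen_kwargs) pairs; key on a stable string form of the kwargs
    groups = defaultdict(list)
    for context, gen_kwargs in requests:
        groups[str(sorted(gen_kwargs.items()))].append((context, gen_kwargs))
    return groups

reqs = [("a", {"temperature": 0.0}), ("b", {"temperature": 0.8}), ("c", {"temperature": 0.0})]
# the two greedy requests batch together; the temperature=0.8 request gets its own batch
assert [len(g) for g in group_by_gen_kwargs(reqs).values()] == [2, 1]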

@@ -1103,7 +1149,7 @@ class HFLM(LM):
                    )
                else:
                    raise ValueError(
-                        f"Expected `kwargs` to be of type `dict` but got {kwargs}"
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                    )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
...

@@ -2,7 +2,7 @@ from typing import Optional, Union
import torch

-from lm_eval import utils
+import lm_eval.models.utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

@@ -56,9 +56,9 @@ class MambaLMWrapper(HFLM):
        super().__init__(
            pretrained=pretrained,
            # set appropriate defaults for tokenizer, max length, etc
-            backend=kwargs.get("backend", "causal"),
-            tokenizer=kwargs.get("tokenizer", "EleutherAI/gpt-neox-20b"),
-            max_length=kwargs.get("max_length", 2048),
+            backend=kwargs.pop("backend", "causal"),
+            tokenizer=kwargs.pop("tokenizer", "EleutherAI/gpt-neox-20b"),
+            max_length=kwargs.pop("max_length", 2048),
            **kwargs,
        )

@@ -97,7 +97,9 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba
        self._model = MambaLMHeadModel.from_pretrained(
            pretrained,
            device=self._device,
-            dtype=torch.float16 if dtype == "auto" else utils.get_dtype(dtype),
+            dtype=torch.float16
+            if dtype == "auto"
+            else lm_eval.models.utils.get_dtype(dtype),
        )

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
...
import copy
import json
import logging
import subprocess
from collections import defaultdict
from typing import List, Optional, Union
import torch
import torch.nn.functional as F
import transformers
from packaging import version
from tqdm import tqdm
from transformers import GenerationConfig
from transformers.generation import StoppingCriteriaList
import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import stop_sequences_criteria
try:
NEURON_AVAILABLE = True
from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.generation import TokenSelector
from optimum.neuron.version import __version__ as optimum_neuron_version
except ImportError:
NeuronModelForCausalLM = object
NEURON_AVAILABLE = False
logger = logging.getLogger(__name__)
def get_nc_count() -> Union[int, None]:
"""Returns the number of neuron cores on the current instance."""
try:
cmd = "neuron-ls --json-output"
result = subprocess.run(cmd, shell=True, capture_output=True)
print(f"inferring nc_count from `neuron-ls` {result.stdout}")
json_output = json.loads(result.stdout)
count = sum([x["nc_count"] for x in json_output])
print(f"nc_count={count}")
return count
except Exception:
return None
def wrap_constant_batch_size(func):
def _decorator(self, input_ids):
"""input_ids a 2D array with batch_size on dim=0
makes sure the func runs with self.batch_size
"""
# actual batch size of the incoming inputs
batch_size = input_ids.shape[0]
if batch_size < self.batch_size:
# handle the event of input_ids.shape[0] != batch_size
# Neuron cores expect constant batch_size
input_ids = torch.concat(
(
input_ids,
# add missing_batch_size dummy
torch.zeros(
[self.batch_size - batch_size, *input_ids.size()[1:]],
dtype=input_ids.dtype,
device=input_ids.device,
),
),
dim=0,
)
elif batch_size > self.batch_size:
raise ValueError(
f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
)
# return the forward pass that requires constant batch size
return func(self, input_ids)[:batch_size]
return _decorator
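
In other words, `wrap_constant_batch_size` pads a short batch up to the model's static batch size with dummy rows and slices the output back down. A hedged usage sketch with a toy class (illustrative only, not harness code):

import torch

class ToyModel:
    batch_size = 4  # static batch size the compiled Neuron graph expects

    @wrap_constant_batch_size
    def forward_padded(self, input_ids):
        # always receives exactly self.batch_size rows; returns one value per row
        assert input_ids.shape[0] == self.batch_size
        return input_ids.sum(dim=-1)

out = ToyModel().forward_padded(torch.ones(2, 5, dtype=torch.long))
print(out.shape)  # torch.Size([2]) -- the two padded dummy rows are sliced off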
class CustomNeuronModelForCausalLM(NeuronModelForCausalLM):
"""NeuronModelForCausalLM with `stopping_criteria` in `generate`"""
def generate(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
stopping_criteria: Optional["StoppingCriteriaList"] = None,
generation_config: Optional["GenerationConfig"] = None,
**kwargs,
) -> torch.LongTensor:
r"""
A streamlined generate() method overriding the transformers.GenerationMixin.generate() method.
This method uses the same logits processors/warpers and stopping criteria as the transformers library
`generate()` method but restricts the generation to greedy search and sampling.
It does not support transformers `generate()` advanced options.
Please refer to https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
for details on generation configuration.
Parameters:
input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices.
generation_config (`~transformers.generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~transformers.generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
Returns:
`torch.Tensor`: A `torch.FloatTensor`.
"""
# The actual generation configuration is a combination of config and parameters
generation_config = copy.deepcopy(
self.generation_config if generation_config is None else generation_config
)
model_kwargs = generation_config.update(
**kwargs
) # All unused kwargs must be model kwargs
# Check model kwargs are actually used by either prepare_inputs_for_generation or forward
self._validate_model_kwargs(model_kwargs)
# Instantiate a TokenSelector for the specified configuration
selector = TokenSelector.create(
input_ids, generation_config, self, self.max_length
)
selector.stopping_criteria.append(stopping_criteria)
# Verify that the inputs are compatible with the model static input dimensions
batch_size, sequence_length = input_ids.shape
if sequence_length > self.max_length:
raise ValueError(
f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
)
padded_input_ids = input_ids
padded_attention_mask = attention_mask
if batch_size > self.batch_size:
raise ValueError(
f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
)
elif batch_size < self.batch_size:
logger.warning(
"Inputs will be padded to match the model static batch size. This will increase latency."
)
padding_shape = [self.batch_size - batch_size, sequence_length]
padding = torch.full(
padding_shape, fill_value=self.config.eos_token_id, dtype=torch.int64
)
padded_input_ids = torch.cat([input_ids, padding])
if attention_mask is not None:
padding = torch.zeros(padding_shape, dtype=torch.int64)
padded_attention_mask = torch.cat([attention_mask, padding])
# Drop the current generation context and clear the Key/Value cache
self.reset_generation()
output_ids = self.generate_tokens(
padded_input_ids,
selector,
batch_size,
attention_mask=padded_attention_mask,
**model_kwargs,
)
return output_ids[:batch_size, :]
@register_model("neuronx")
class NEURON_HF(TemplateLM):
"""
Enables usage on AWS Neuron
using the HuggingFace Transformers + Transformers neuronx library.
Tested with neuron 2.17.0
"""
_DEFAULT_MAX_LENGTH = 2048
def __init__(
self,
pretrained: Optional[str] = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
revision: Optional[str] = "main",
tp_degree: Optional[int] = None,
subfolder: Optional[str] = None,
tokenizer: Optional[str] = None,
truncation: Optional[bool] = False,
max_length: Optional[int] = None,
dtype: Optional[Union[str, torch.dtype]] = "auto",
batch_size: Optional[int] = 1,
low_cpu_mem_usage: Optional[bool] = True,
trust_remote_code: Optional[bool] = False,
use_fast_tokenizer: Optional[bool] = True,
add_bos_token: Optional[bool] = False,
) -> None:
if not NEURON_AVAILABLE:
raise Exception(
"Tried to load neuron model, but neuron is not installed ",
"please install neuron via pip install transformers-neuron ",
"also make sure you are running on an AWS inf2 instance",
)
if version.parse(optimum_neuron_version) != version.parse("0.0.17"):
logger.warning(
'`optimum-neuron` model requires `pip install "optimum[neuronx]>=0.0.17" '
"preferably using the Hugging Face Neuron Deep Learning AMI (Ubuntu 22.04) "
"https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2 "
f"You are using optimum-neuron={optimum_neuron_version}"
)
super().__init__()
assert isinstance(pretrained, str)
assert isinstance(batch_size, (int, str))
self.batch_size_per_gpu = int(batch_size)
batch_size = int(batch_size)
if tp_degree is None:
# execute `neuron-ls --json-output | jq '.[0].nc_count'`
# to get the number of neuron cores on your instance
tp_degree = get_nc_count()
assert isinstance(tp_degree, int), (
f"model_args must include tp_degree. tp_degree must be set to an integer,"
f" but is tp_degree=`{tp_degree}` with type=`{type(tp_degree)}`."
"Set it to number of neuron cores on your instance."
" For inf2.xlarge and inf2.8xlarge, set it to `2`."
" For inf2.24xlarge, set it to `12`."
" For inf2.48xlarge, set it to `24`."
)
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
)
torch_dtype = lm_eval.models.utils.get_dtype(dtype)
assert torch_dtype in [
torch.float16,
torch.bfloat16,
], "Only float16 and bfloat16 are supported"
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
)
# Neuron specific code
if torch_dtype == torch.float16:
self.amp_dtype = "f16"
elif torch_dtype == torch.bfloat16:
self.amp_dtype = "bf16"
elif torch_dtype == torch.float32:
self.amp_dtype = "f32"
else:
raise NotImplementedError("Only float16 and bfloat16 are implemented.")
compiler_args = {"num_cores": tp_degree, "auto_cast_type": self.amp_dtype}
input_shapes = {
"batch_size": batch_size,
"sequence_length": self._DEFAULT_MAX_LENGTH,
}
print(
f"{'='*20} \n loading model to neuron with"
f" {compiler_args}, {input_shapes}..."
)
self.model = CustomNeuronModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=low_cpu_mem_usage,
export=True,
**compiler_args,
**input_shapes,
)
print(f"SUCCESS: neuron model compiled. \n {'='*20}")
self.truncation = truncation
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.add_bos_token = add_bos_token
self._max_length = max_length
self.batch_schedule = 1
self.batch_sizes = {}
@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
if self._max_length: # if max length manually set, return it
return self._max_length
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self.model.config, attr):
return getattr(self.model.config, attr)
if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
@property
def max_gen_toks(self) -> int:
return 256
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def device(self):
"""device are neuron cores, but the created tensors are on CPU."""
return "cpu"
@property
def rank(self):
return 0
@property
def world_size(self):
return 1
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
""" """
if add_special_tokens is None:
add_special_tokens = False or self.add_bos_token
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding
def tok_batch_encode(
self,
strings: List[str],
padding_side: str = "left",
left_truncate_len: int = None,
truncation: bool = False,
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
add_special_tokens = False or self.add_bos_token
encoding = self.tokenizer(
strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
add_special_tokens=add_special_tokens,
)
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len:
]
self.tokenizer.padding_side = old_padding_side
return encoding["input_ids"], encoding["attention_mask"]
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
@wrap_constant_batch_size
def _model_call(self, input_ids: torch.Tensor):
"""
get logits for the entire sequence
:param input_ids: torch.Tensor
A torch tensor of shape [batch, sequence_cont]
the size of sequence may vary from call to call
:return
A torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model's decoder-lm head
"""
_, sequence_length = input_ids.shape
with torch.inference_mode():
cache_ids = torch.arange(0, sequence_length, dtype=torch.int32).split(1)
input_ids_split = input_ids.split(1, dim=1)
return torch.concat(
[
self.model.forward(
input_ids=input_id, cache_ids=cache_id, return_dict=False
)[0]
for input_id, cache_id in zip(input_ids_split, cache_ids)
],
dim=1,
)
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
with torch.inference_mode():
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
stopping_criteria = stop_sequences_criteria(
self.tokenizer,
stop + [self.tokenizer.decode([self.config.eos_token_id])],
1,
context.shape[0],
)
return self.model.generate(
input_ids=context,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
def _select_cont_toks(self, logits, contlen=None, inplen=None):
assert (
contlen and inplen
), "Must pass input len and cont. len to select scored logits for causal LM"
# discard right-padding.
# also discard the input/context tokens. we'll only score continuations.
logits = logits[inplen - contlen : inplen]
return logits
def loglikelihood_rolling(self, requests):
loglikelihoods = []
adaptive_batch_size = None
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
pad_amnt = 0
if self.world_size > 1:
# We pad out the external document-level iterator so the inner iterator doesn't hang
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
gathered = (
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
)
pad_amnt = max(gathered) - gathered[self.rank]
if pad_amnt > 0:
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
string_nll = self._loglikelihood_tokens(
rolling_token_windows,
disable_tqdm=True,
override_bs=adaptive_batch_size,
)
if (self.world_size > 1) and (pad_amnt > 0):
string_nll = [x[0] for x in string_nll[:-pad_amnt]]
else:
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def _loglikelihood_tokens(
self, requests, disable_tqdm: bool = False, override_bs=None
):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[1] + x[2]
return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)
n_reordered_requests = len(re_ord.get_reordered()) # noqa
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
chunks = lm_eval.models.utils.chunks(
re_ord.get_reordered(),
n=self.batch_size,
fn=None,
)
for chunk in tqdm(chunks, disable=(disable_tqdm or (self.rank != 0))):
inps = []
cont_toks_list = []
inplens = []
conts = [] # noqa
encoder_attns = [] # noqa
padding_len_inp = None
padding_len_cont = None # noqa
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works (illustrated on a causal decoder-only setup):
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
# when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
device=self.device,
)
(inplen,) = inp.shape
padding_len_inp = (
max(padding_len_inp, inplen)
if padding_len_inp is not None
else inplen
)
inps.append(inp) # [1, inp_length]
cont_toks_list.append(continuation_enc)
inplens.append(inplen)
# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
batched_inps = lm_eval.models.utils.pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
multi_logits = F.log_softmax(
self._model_call(batched_inps, **call_kwargs), dim=-1
) # [batch, padding_length (inp or cont), vocab]
for (cache_key, _, _), logits, inplen, cont_toks in zip(
chunk, multi_logits, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
# take only logits in the continuation
# (discard context toks if decoder-only ; discard right-padding)
# also discards + checks for "virtual tokens" in the causal LM's input window
# from prompt/prefix tuning tokens, if applicable
ctx_len = inplen + (logits.shape[0] - padding_len_inp)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
logits = logits.unsqueeze(0) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(
cont_toks, dtype=torch.long, device=self.device
).unsqueeze(0) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
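
The CTX/CONT diagram in `_loglikelihood_tokens` is easiest to verify on concrete indices. A small worked sketch of the slicing arithmetic used above (no model call, illustrative token ids only):

context_enc = [0, 1, 2, 3]
continuation_enc = [4, 5, 6]
max_length = 16

# the model is fed everything except the final continuation token,
# because each position's logits predict the *next* token
inp = (context_enc + continuation_enc)[-(max_length + 1):][:-1]
assert inp == [0, 1, 2, 3, 4, 5]

# the continuation's logits live at positions inplen - contlen .. inplen
inplen, contlen = len(inp), len(continuation_enc)
assert (inplen - contlen, inplen) == (3, 6)  # logits[3:6] predict tokens 4, 5, 6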
def generate_until(self, requests):
res = defaultdict(list)
re_ords = {}
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return -len(toks), x[0]
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
chunks = lm_eval.models.utils.chunks(
re_ord.get_reordered(), n=self.batch_size
)
for chunk in tqdm(chunks, disable=self.rank != 0):
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
primary_until = [until[0]]
max_ctx_len = self.max_length - max_gen_toks
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
stop=primary_until,
**kwargs,
)
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
cont_toks = cont_toks[context_enc.shape[1] :]
s = self.tok_decode(cont_toks)
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0:
# ignore '' separator,
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
res[key].append(s)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close()
return grouper.get_original(res)

@@ -6,10 +6,12 @@ from typing import List, Literal, Optional, Tuple
from tqdm import tqdm

+import lm_eval.models.utils
from lm_eval import utils
-from lm_eval.api.model import LM
+from lm_eval.api.model import LM, TemplateLM
from lm_eval.api.registry import register_model
-from lm_eval.utils import eval_logger, retry_on_specific_exceptions
+from lm_eval.models.utils import retry_on_specific_exceptions
+from lm_eval.utils import eval_logger

def get_result(response, ctxlen: int) -> Tuple[float, bool]:

@@ -73,7 +75,7 @@ def oa_completion(client, chat: bool = False, **kwargs):
@register_model("openai-completions", "local-completions")
-class OpenaiCompletionsLM(LM):
+class OpenaiCompletionsLM(TemplateLM):
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(

@@ -169,41 +171,12 @@ class OpenaiCompletionsLM(LM):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

-    def tok_encode(self, string: str) -> List[int]:
+    def tok_encode(self, string: str, **kwargs) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-        return self._loglikelihood_tokens(new_reqs)

    def _loglikelihood_tokens(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:

@@ -219,7 +192,7 @@ class OpenaiCompletionsLM(LM):
        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
-            list(utils.chunks(re_ord.get_reordered(), self.batch_size)),
+            list(lm_eval.models.utils.chunks(re_ord.get_reordered(), self.batch_size)),
            disable=disable_tqdm,
        ):
            inps = []

@@ -288,14 +261,13 @@ class OpenaiCompletionsLM(LM):
            list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size))
        ):
            inps = []
-            self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
+            self._max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

-            until = request_args.pop("until", ["<|endoftext|>"])
-            request_args.pop("do_sample", None)
+            until = request_args.get("until", ["<|endoftext|>"])
            request_args["temperature"] = request_args.get("temperature", 0)

            response = oa_completion(

@@ -305,7 +277,11 @@ class OpenaiCompletionsLM(LM):
                max_tokens=self.max_gen_toks,
                stop=until,
                seed=self.seed,
-                **request_args,
+                **{
+                    k: v
+                    for k, v in request_args.items()
+                    if k not in ["do_sample", "max_gen_toks"]
+                },
            )

            for resp, (context, args_) in zip(response.choices, chunk):
                s = getattr(resp, "text")

@@ -429,7 +405,7 @@ class OpenaiChatCompletionsLM(LM):
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
-        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer(

@@ -441,7 +417,7 @@ class OpenaiChatCompletionsLM(LM):
            # n needs to be 1 because messages in
            # chat completion are not batch but
            # is regarded as a single conversation.
-            chunks = utils.chunks(re_ord.get_reordered(), n=1)
+            chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]
...

@@ -28,7 +28,7 @@ class OptimumLM(HFLM):
        super().__init__(
            device=self.openvino_device,
-            backend=kwargs.get("backend", "causal"),
+            backend=kwargs.pop("backend", "causal"),
            **kwargs,
        )
...

@@ -19,7 +19,7 @@ from tqdm import tqdm
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
-from lm_eval.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import retry_on_specific_exceptions

logger = logging.getLogger(__name__)
...
import collections
import fnmatch
import gc
import time
from functools import wraps
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
import torch
import transformers
from lm_eval.utils import eval_logger
def chunks(iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
Parameters:
- iter: The input iterable to be divided into chunks.
- n: An integer representing the size of each chunk. Default is 0.
- fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
Returns:
An iterator that yields chunks of the input iterable.
Example usage:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for chunk in chunks(data, 3):
print(chunk)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]
```
"""
arr = []
for i, x in enumerate(iter):
arr.append(x)
if len(arr) == (fn(i, iter) if fn else n):
yield arr
arr = []
if arr:
yield arr
class MultiChoice:
def __init__(self, choices) -> None:
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values) -> bool:
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.info("Available tasks to choose:")
for choice in self.choices:
eval_logger.info(f" - {choice}")
raise ValueError("'{}' is not in task list".format(value))
return True
def __iter__(self) -> Iterator:
for choice in self.choices:
yield choice
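
A short usage sketch of MultiChoice (the task names here are placeholders):

tasks = MultiChoice(["arc_easy", "arc_challenge", "hellaswag"])
assert "arc_*" in tasks               # wildcard matches at least one known task
assert "arc_easy,hellaswag" in tasks  # comma-separated selections are checked one by one
# "unknown_task" in tasks would log the available choices and raise ValueError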
class Grouper:
"""
takes an array `arr` and function `fn` and returns a dictionary
with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all
objects in `arr` satisfying `key == fn(ob)`.
"""
def __init__(self, arr, fn) -> None:
# self.orig_arr = arr
self.size = len(arr)
arr = list(enumerate(arr))
def group_return_dict(arr, fn):
res = collections.defaultdict(list)
for ob in arr:
res[fn(ob)].append(ob)
return res
arr = group_return_dict(arr, lambda x: fn(x[1]))
# self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
self.arr = arr
self._grouped = None
def get_grouped(self):
# return the contents but not indices for our grouped dict.
if self._grouped:
return self._grouped
grouped = {}
for key in self.arr.keys():
# drop the index from each element of self.arr
grouped[key] = [y[1] for y in self.arr[key]]
self._grouped = grouped
return grouped
def get_original(self, grouped_dict):
# take in a grouped dictionary with e.g. results for each key listed
# in the same order as the instances in `self.arr`, and
# return the results in the same (single list) order as `self.orig_arr`.
res = [None] * self.size
cov = [False] * self.size
# orig = [None] * self.size
assert grouped_dict.keys() == self.arr.keys()
for key in grouped_dict.keys():
for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
res[ind] = v
cov[ind] = True
# orig[ind] = _
assert all(cov)
# assert orig == self.orig_arr
return res
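# Usage sketch (toy values): group strings by length with `Grouper`, then map
# per-group results back to the original ordering.
#
#   grouper = Grouper(["a", "bb", "c"], len)
#   grouper.get_grouped()                               # {1: ["a", "c"], 2: ["bb"]}
#   grouper.get_original({1: ["A", "C"], 2: ["BB"]})    # ["A", "BB", "C"]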
def pad_and_concat(
max_length: int,
tensors: List[torch.Tensor],
padding_side: Literal["right", "left"] = "right",
):
"""
Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in
seq2seq models.
"""
assert (
padding_side == "left" or padding_side == "right"
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors):
if len(tensor.shape) == 2:
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0]
if tensor_len < max_length:
if padding_side == "right":
# right-pad
tensors[i] = torch.cat(
[
tensor, # [seq]
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
],
dim=0,
).unsqueeze(0)
else:
# left-pad
tensors[i] = torch.cat(
[
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
tensor, # [seq]
],
dim=0,
).unsqueeze(0)
else:
tensors[i] = tensor.unsqueeze(0)
return torch.cat(tensors, dim=0)
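# Usage sketch (toy tensors): pad a ragged batch of token-id tensors to a
# common length before stacking them into a single [batch, max_length] tensor.
#
#   a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
#   pad_and_concat(3, [a, b], padding_side="left")
#   # tensor([[1, 2, 3],
#   #         [0, 4, 5]])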
def clear_torch_cache() -> None:
gc.collect()
torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
) -> None:
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
# print(sequence, self.sequence_ids)
# we look back for 2 more tokens than it takes to encode our stop sequence
# because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
# and we don't want to mistakenly not stop a generation because our
# (string) stop sequence was output in a different tokenization
# NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
# and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
# Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
self.sequence_id_len = len(self.sequence_ids) + 2
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker
def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
initial_decoder_input_length: int,
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList(
[
*[
MultiTokenEOSCriteria(
sequence, tokenizer, initial_decoder_input_length, batch_size
)
for sequence in stop_sequences
],
]
)
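# Usage sketch (assumes a Hugging Face `model`, its `tokenizer`, and a batch of
# tokenized prompts `input_ids` of shape [batch, seq]): generation halts once
# every sequence in the batch has emitted the stop string.
#
#   stopping_criteria = stop_sequences_criteria(
#       tokenizer, ["\n\n"], input_ids.shape[1], input_ids.shape[0]
#   )
#   out = model.generate(
#       input_ids, max_new_tokens=64, stopping_criteria=stopping_criteria
#   )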
def divide(iterable, n) -> List[Iterator]:
"""Divide the elements from *iterable* into *n* parts, maintaining
order.
>>> group_1, group_2 = divide([1, 2, 3, 4, 5, 6], 2)
>>> list(group_1)
[1, 2, 3]
>>> list(group_2)
[4, 5, 6]
If the length of *iterable* is not evenly divisible by *n*, then the
length of the returned iterables will not be identical:
>>> children = divide([1, 2, 3, 4, 5, 6, 7], 3)
>>> [list(c) for c in children]
[[1, 2, 3], [4, 5], [6, 7]]
If the length of the iterable is smaller than n, then the last returned
iterables will be empty:
>>> children = divide([1, 2, 3], 5)
>>> [list(c) for c in children]
[[1], [2], [3], [], []]
This function will exhaust the iterable before returning and may require
significant storage. If order is not important, see :func:`distribute`,
which does not first pull the iterable into memory.
"""
if n < 1:
raise ValueError("n must be at least 1")
try:
iterable[:0]
except TypeError:
seq = tuple(iterable)
else:
seq = iterable
q, r = divmod(len(seq), n)
ret = []
stop = 0
for i in range(1, n + 1):
start = stop
stop += q + 1 if i <= r else q
ret.append(iter(seq[start:stop]))
return ret
def retry_on_specific_exceptions(
on_exceptions: List[Type[Exception]],
max_retries: Optional[int] = None,
backoff_time: float = 3.0,
backoff_multiplier: float = 1.5,
on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
"""Retry on an LLM Provider's rate limit error with exponential backoff
For example, to use for OpenAI, do the following:
```
from openai import RateLimitError
# Recommend specifying max_retries to avoid infinite loops!
@retry_on_specific_exceptions([RateLimitError], max_retries=3)
def completion(...):
# Wrap OpenAI completion function here
...
```
"""
def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
sleep_time = backoff_time
attempt = 0
while max_retries is None or attempt < max_retries:
try:
return func(*args, **kwargs)
except tuple(on_exceptions) as e:
if on_exception_callback is not None:
on_exception_callback(e, sleep_time)
time.sleep(sleep_time)
sleep_time *= backoff_multiplier
attempt += 1
return wrapper
return decorator
class Collator:
"""
A class for reordering and batching elements of an array.
This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
Objects of this class have the group_by attribute which determines the method for grouping
the data while batching it. Three options include "gen_kwargs", "contexts", or None:
If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
If group_by == "contexts" then requests will be grouped by context + cont[:-1]
If None then requests will just be reordered by length descending.
"""
def __init__(
self,
arr: List,
sort_fn: Callable = lambda x: x,
group_fn: Callable = lambda x: x[1],
group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
) -> None:
self._group_by = group_by
# 0 indices are enumerated indices. Apply functions to original arr.
self._sort_fn = lambda x: sort_fn(x[1])
self._group_fn = lambda x: group_fn(x[1])
self._reorder_indices: List = []
self._size = len(arr)
self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
enumerate(arr)
) # [indices, (arr)]
if self._group_by == "contexts":
self._group_by_context()
elif self._group_by == "gen_kwargs":
self._group_by_index()
def _group_by_index(self) -> None:
"""Group the elements of a list based on their indices."""
self._arr_with_indices = self.group(
self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
)
def _group_by_context(self) -> None:
"""Group the array with indices by context."""
self._arr_with_indices = self.group(
self._arr_with_indices, fn=self._group_fn, group_by="contexts"
)
def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
"""
Generates and yields batches from the reordered array. The method of grouping and batching
depends on the parameter `group_by`.
If `group_by` is set to "gen_kwargs", it will batch the
re-ordered values with same gen_kwargs for each batch.
If `group_by` is "contexts", it caches the requests by context before batching.
If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array
Parameters:
- n (int): The size of each batch. Defaults to 1.
- batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
each batch. Optional, defaults to None.
Returns:
Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
attribute.
Yields:
List of batched elements according to the `group_by` attribute.
"""
if self._group_by == "gen_kwargs":
for (
key,
values,
) in self._arr_with_indices.items(): # type: ignore
values = self._reorder(values)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
elif self._group_by == "contexts":
# Get one sample from each key
values = self._reorder(
[value[0] for value in self._arr_with_indices.values()]
)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
else:
values = self._reorder(self._arr_with_indices) # type: ignore
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
def get_cache(
self,
req_str: Tuple[str, str] = None,
cxt_toks: List[int] = None,
cont_toks: List[int] = None,
logits: torch.Tensor = None,
) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
"""
Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.
The behavior of this function varies depending on how the `group_by` attribute is set:
- When `group_by` is "contexts":
The function identifies single-token continuations by checking for keys that equate to
[context+continuation][-1] and logs the indices for re-ordering.
In this mode, this function can work in two scenarios:
1. Cache Hit - Single Match:
If a single matching context-continuation pair is found in the cache,
the function yields the original arguments.
2. Cache Hit - Multiple Matches:
If multiple matching context-continuation pairs are found in the cache,
the function expands the logits batch dimension to match the number of cache hits.
It updates the original requests and continuation tokens.
- When `group_by` is not set to "contexts":
This method yields the original arguments, logits and continuation tokens,
without checking for one-token continuations.
Parameters:
- req_str (tuple[str, str]): Original strings used for CachingLM.
- cxt_toks (list[int]): Full context tokens used for lookup.
- cont_toks (list[int]): Continuation tokens for which logits were generated.
- logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.
Yields:
- Iterator:
- req_str (tuple[str, str]): strings used for CachingLM.
- cont_toks (list[int]) : continuation tokens.
- logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
"""
if self._group_by == "contexts":
cache_hit: List[
Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
if (cache_size := len(cache_hit)) == 1:
self._reorder_indices.extend(x[0] for x in cache_hit)
yield req_str, cont_toks, logits
else:
# If we have matching requests then expand the batch dimension (no-op) and
# yield each along with its corresponding args.
multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
indices, req_str, cont_toks = zip(
*[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
)
self._reorder_indices.extend(indices)
for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
yield c_key, cont_tok, logit
else:
yield req_str, cont_toks, logits
def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
"""
Reorders the elements in the array based on the sorting function.
Parameters:
- arr (list | tuple[tuple[int, Any], ...]): The array or iterable to be reordered.
Yields:
Iterator
"""
arr = sorted(arr, key=self._sort_fn)
if not self._group_by == "contexts":
# If grouped by contexts then indices will be set in get_cache()
self._reorder_indices.extend([x[0] for x in arr])
yield from [x[1] for x in arr]
def get_original(self, newarr: List) -> List:
"""
Restores the original order of elements from the reordered list.
Parameters:
- newarr (list): The reordered array.
Returns:
list: The array with elements restored to their original order.
"""
res = [None] * self._size
cov = [False] * self._size
for ind, v in zip(self._reorder_indices, newarr):
res[ind] = v
cov[ind] = True
assert all(cov)
return res
def __len__(self):
return self._size
@staticmethod
def group(
arr: Iterable,
fn: Callable,
group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
) -> dict:
"""
Groups elements of an iterable based on a provided function.
The `group_by` parameter determines the method of grouping.
If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.
Parameters:
- arr (Iterable): The iterable to be grouped.
- fn (Callable): The function to determine the grouping.
- group_by ("gen_kwargs" | "contexts"): The method of grouping. Defaults to "gen_kwargs".
Returns:
dict: A dictionary mapping each group key to the list of grouped elements.
"""
res = collections.defaultdict(list)
for ob in arr:
# where ob == [context + cont]
if group_by == "contexts":
res[tuple(fn(ob))].append(ob)
else:
try:
hashable_dict = tuple(
(
key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
)
for key, value in sorted(fn(ob).items())
)
res[hashable_dict].append(ob)
except (TypeError, AttributeError):
res[tuple(fn(ob))].append(ob)
return res
@staticmethod
def get_chunks(_iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
Parameters:
- _iter: The input iterable to be divided into chunks.
- n: An integer representing the size of each chunk. Default is 0.
- fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
Returns:
An iterator that yields chunks of the input iterable.
Example usage:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for chunk in chunks(data, 3):
print(chunk)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]
```
"""
arr = []
_iter = tuple(_iter)
for i, x in enumerate(_iter):
arr.append(x)
if len(arr) == (fn(i, _iter) if fn else n):
yield arr
arr = []
if arr:
yield arr
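A minimal usage sketch of `Collator` with toy requests (the model call is stubbed out): requests are sorted by context length, batched per unique `gen_kwargs`, and the results are restored to their original order at the end.
```
from lm_eval.models.utils import Collator

# (context, gen_kwargs) pairs standing in for generate_until-style requests
reqs = [
    ("a b c", {"temperature": 0.0}),
    ("d", {"temperature": 0.0}),
    ("e f", {"temperature": 0.8}),
]

collator = Collator(
    reqs,
    sort_fn=lambda req: len(req[0]),  # order by context length within a group
    group_fn=lambda req: req[1],      # one batch stream per unique gen_kwargs
    group_by="gen_kwargs",
)

results = []
for batch in collator.get_batched(n=2):
    results.extend(ctx.upper() for ctx, _ in batch)  # stand-in for model calls

ordered = collator.get_original(results)  # back in the original request order
```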
...@@ -5,11 +5,10 @@ from typing import List, Literal, Optional, Tuple, Union ...@@ -5,11 +5,10 @@ from typing import List, Literal, Optional, Tuple, Union
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.model import LM from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, divide
from lm_eval.utils import ( from lm_eval.utils import (
Collator,
divide,
eval_logger, eval_logger,
get_rolling_token_windows, get_rolling_token_windows,
make_disjoint_window, make_disjoint_window,
...@@ -36,7 +35,7 @@ def run_inference_one_model( ...@@ -36,7 +35,7 @@ def run_inference_one_model(
@register_model("vllm") @register_model("vllm")
class VLLM(LM): class VLLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048 _DEFAULT_MAX_LENGTH = 2048
def __init__( def __init__(
...@@ -48,6 +47,7 @@ class VLLM(LM): ...@@ -48,6 +47,7 @@ class VLLM(LM):
tokenizer: Optional[str] = None, tokenizer: Optional[str] = None,
tokenizer_mode: Literal["auto", "slow"] = "auto", tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
add_bos_token: Optional[bool] = False,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
quantization: Optional[str] = None, quantization: Optional[str] = None,
max_gen_toks: int = 256, max_gen_toks: int = 256,
...@@ -115,6 +115,7 @@ class VLLM(LM): ...@@ -115,6 +115,7 @@ class VLLM(LM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, tokenizer_revision=tokenizer_revision,
) )
self.add_bos_token = add_bos_token
self._max_gen_toks = max_gen_toks self._max_gen_toks = max_gen_toks
...@@ -148,10 +149,12 @@ class VLLM(LM): ...@@ -148,10 +149,12 @@ class VLLM(LM):
self, self,
string: str, string: str,
left_truncate_len=None, left_truncate_len=None,
add_special_tokens=False, add_special_tokens=None,
truncation=False, truncation=False,
): ):
""" """ """ """
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding = self.tokenizer.encode( encoding = self.tokenizer.encode(
string, add_special_tokens=add_special_tokens, truncation=truncation string, add_special_tokens=add_special_tokens, truncation=truncation
) )
...@@ -195,37 +198,6 @@ class VLLM(LM): ...@@ -195,37 +198,6 @@ class VLLM(LM):
) )
return outputs return outputs
def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
context_enc = self.tok_encode(context, add_special_tokens=False)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc, continuation_enc = (
[self.eot_token_id],
self.tok_encode(continuation),
)
else:
context_enc, continuation_enc = self._encode_pair(context, continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
loglikelihoods = [] loglikelihoods = []
...@@ -277,12 +249,16 @@ class VLLM(LM): ...@@ -277,12 +249,16 @@ class VLLM(LM):
# we group requests by their generation_kwargs, # we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch. # in the same batch.
re_ords = Collator(requests, _collate_gen, grouping=True) re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
chunks = re_ords.get_batched( chunks = re_ords.get_batched(
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
) )
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(
total=len(requests),
disable=(self.rank != 0),
desc="Running generate_until requests",
)
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
for chunk in chunks: for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk) context_and_encoding, all_gen_kwargs = zip(*chunk)
...@@ -357,7 +333,11 @@ class VLLM(LM): ...@@ -357,7 +333,11 @@ class VLLM(LM):
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
) )
pbar = tqdm(total=len(requests), disable=disable_tqdm) pbar = tqdm(
total=len(requests),
disable=disable_tqdm,
desc="Running loglikelihood requests",
)
for chunk in chunks: for chunk in chunks:
inputs = [] inputs = []
ctxlens = [] ctxlens = []
......
import os
import ast import ast
import os
from typing import Dict from typing import Dict
from lm_eval import utils from lm_eval import utils
from lm_eval.utils import eval_logger from lm_eval.utils import eval_logger
# Prompt library. # Prompt library.
# Stores prompts in a dictionary indexed by 2 levels: # Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name. # prompt category name, and prompt name.
......
import abc
import collections
import logging
import os import os
import yaml from functools import partial
from typing import List, Union, Dict from typing import Dict, List, Union
from lm_eval import utils from lm_eval import utils
from lm_eval import prompts from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.registry import (
register_task,
register_group,
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
)
import logging
# import python tasks class TaskManager:
from .squadv2.task import SQuAD2 """TaskManager indexes all tasks from the default `lm_eval/tasks/`
from .scrolls.task import ( and an optional directory if provided.
QuALITY,
NarrativeQA,
ContractNLI,
GovReport,
SummScreenFD,
QMSum,
)
eval_logger = utils.eval_logger
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
{"CONFIG": TaskConfig(**config)},
)
if "task" in config: """
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
if "group" in config: def __init__(self, verbosity="INFO", include_path=None) -> None:
if config["group"] == config["task"]: self.verbosity = verbosity
raise ValueError("task and group name cannot be the same") self.include_path = include_path
elif isinstance(config["group"], str): self.logger = utils.eval_logger
group_name = [config["group"]] self.logger.setLevel(getattr(logging, f"{verbosity}"))
self._task_index = self.initialize_tasks(include_path=include_path)
self._all_tasks = sorted(list(self._task_index.keys()))
self.task_group_map = collections.defaultdict(list)
def initialize_tasks(self, include_path: str = None):
"""Creates an dictionary of tasks index.
:param include_path: str = None
An additional path to be searched for tasks
:return
Dictionary of task names as key and task metadata
"""
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_path is not None:
if isinstance(include_path, str):
include_path = [include_path]
all_paths.extend(include_path)
task_index = {}
for task_dir in all_paths:
tasks = self._get_task_and_group(task_dir)
task_index = {**tasks, **task_index}
return task_index
@property
def all_tasks(self):
return self._all_tasks
@property
def task_index(self):
return self._task_index
def match_tasks(self, task_list):
return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name):
if name in self.all_tasks:
return True
return False
def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
return True
return False
def _name_is_group(self, name):
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True
return False
def _name_is_python_task(self, name):
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True
return False
def _config_is_task(self, config):
if ("task" in config) and isinstance(config["task"], str):
return True
return False
def _config_is_group(self, config):
if ("task" in config) and isinstance(config["task"], list):
return True
return False
def _config_is_python_task(self, config):
if "class" in config:
return True
return False
def _get_yaml_path(self, name):
assert name in self.task_index
return self.task_index[name]["yaml_path"]
def _get_config(self, name):
assert name in self.task_index
yaml_path = self._get_yaml_path(name)
if yaml_path == -1:
return {}
else: else:
group_name = config["group"] return utils.load_yaml_config(yaml_path, mode="full")
for group in group_name: def _get_tasklist(self, name):
register_group(group)(SubClass) assert self._name_is_task(name) is False
return self.task_index[name]["task"]
def _process_alias(self, config, group=None):
# If the group is not the same as the original
# group which the group alias was intended for,
# Set the group_alias to None instead.
if ("group_alias" in config) and ("group" in config) and group is not None:
if config["group"] != group:
config["group_alias"] = None
return config
def _load_individual_task_or_group(
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None,
yaml_path: str = None,
) -> ConfigurableTask:
def load_task(config, task, group=None, yaml_path=None):
if "include" in config:
assert yaml_path is not None
config.update(
utils.load_yaml_config(
yaml_path,
yaml_config={"include": config.pop("include")},
mode="full",
)
)
if self._config_is_python_task(config):
task_object = config["class"]()
else:
config = self._process_alias(config, group=group)
task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object}
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
return load_task(task_config, task=name_or_config, group=parent_name)
else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1:
group_config = self._get_config(name_or_config)
subtask_list = group_config["task"]
# This checks if we're at the root.
if parent_name is None:
group_config = self._get_config(name_or_config)
if set(group_config.keys()) > set(["task", "group"]):
update_config = {
k: v
for k, v in group_config.items()
if k not in ["task", "group"]
}
yaml_path = self._get_yaml_path(group_name)
return 0 if (update_config is not None) and ("group_alias" in update_config):
group_name = update_config["group_alias"]
update_config.pop("group_alias")
if isinstance(name_or_config, dict):
if update_config is not None:
name_or_config = {
**name_or_config,
**update_config,
}
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int: if self._config_is_task(name_or_config):
group = config["group"] name = name_or_config["task"]
all_task_list = config["task"] # If the name is registered as a group
config_list = [task for task in all_task_list if not isinstance(task, str)] # if self._name_is_task(name) is False:
task_list = [task for task in all_task_list if isinstance(task, str)] if self._name_is_group(name):
group_name = name
for task_config in config_list: update_config = {
k: v for k, v in name_or_config.items() if k != "task"
base_config = {} }
task_name_config = {} subtask_list = self._get_tasklist(name)
if "task" in task_config: if subtask_list == -1:
task_name = task_config["task"] subtask_list = self._get_config(name)["task"]
if task_name in ALL_TASKS: else:
task_obj = TASK_REGISTRY[task_name] if self._name_is_registered(name):
if isinstance(task_obj, tuple): base_task_config = self._get_config(name)
_, task_obj = task_obj
# Check if this is a duplicate.
if task_obj is not None: if parent_name is not None:
base_config = task_obj.CONFIG.to_dict(keep_callable=True) name_or_config["group"] = parent_name
task_name_config["task"] = f"{group}_{task_name}" num_duplicate = len(
list(
task_config = utils.load_yaml_config(yaml_path, task_config) filter(
var_configs = check_prompt_config( lambda x: x.startswith(name),
{ self.task_group_map[parent_name],
**base_config, )
**task_config, )
**{"group": group}, )
**task_name_config, if num_duplicate > 0:
}, name = f"{name}-{num_duplicate}"
yaml_path=os.path.dirname(yaml_path), self.task_group_map[parent_name].append(name)
)
for config in var_configs: task_config = {
register_configurable_task(config) **base_task_config,
**name_or_config,
task_names = utils.pattern_match(task_list, ALL_TASKS) }
for task in task_names: else:
if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY): task_config = name_or_config
if group in GROUP_REGISTRY: return load_task(
GROUP_REGISTRY[group].append(task) task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else: else:
GROUP_REGISTRY[group] = [task] group_name = name_or_config["group"]
ALL_TASKS.add(group) subtask_list = name_or_config["task"]
if set(name_or_config.keys()) > set(["task", "group"]):
return 0 update_config = {
k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
}
all_subtasks = {}
if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)}
def check_prompt_config( fn = partial(
config: Dict[str, str], yaml_path: str = None self._load_individual_task_or_group,
) -> List[Dict[str, str]]: parent_name=group_name,
all_configs = [] update_config=update_config,
if "use_prompt" in config:
prompt_list = prompts.load_prompt_list(
use_prompt=config["use_prompt"],
dataset_name=config["dataset_path"],
subset_name=config["dataset_name"] if "dataset_name" in config else None,
yaml_path=yaml_path, yaml_path=yaml_path,
) )
for idx, prompt_variation in enumerate(prompt_list): all_subtasks = {
all_configs.append( **all_subtasks,
{ **dict(collections.ChainMap(*map(fn, subtask_list))),
**config, }
**{"use_prompt": prompt_variation}, return all_subtasks
**{
"task": "_".join( def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
[ """Loads a dictionary of task objects from a list
config["task"]
if "task" in config :param task_list: Union[str, list] = None
else get_task_name_from_config(config), Single string or list of string of task names to be loaded
prompt_variation.split("/")[-1]
if ".yaml" in prompt_variation :return
else prompt_variation, Dictionary of task objects
] """
) if isinstance(task_list, str):
}, task_list = [task_list]
**{"output_type": "generate_until"},
} all_loaded_tasks = dict(
) collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
else: )
all_configs.append(config) return all_loaded_tasks
return all_configs
def load_config(self, config: Dict):
return self._load_individual_task_or_group(config)
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "dataset_name" in task_config: def _get_task_and_group(self, task_dir: str):
return "{dataset_path}_{dataset_name}".format(**task_config) """Creates an dictionary of tasks index with the following metadata,
else: - `type`, that can be either `task`, `python_task`, or `group`.
return "{dataset_path}".format(**task_config) `task` refer to regular task configs, `python_task` are special
yaml files that only consists of `task` and `class` parameters.
`group` are group configs.
- `yaml_path`, path to the yaml file. If the entry is a `group` that
was configured through a task config, the yaml_path will be -1
and all subtasks will be listed in `task` (see below)
- `task`, reserved for entries with `type` as `group`. This will list
all subtasks. When a group config is created (as opposed to task
config having `group` parameter set), this will be set to -1 to
avoid recursive indexing. The whole list of subtasks will be loaded
at evaluation.
:param task_dir: str
A directory to check for tasks
:return
Dictionary of task names as key and task metadata
"""
tasks_and_groups = collections.defaultdict()
for root, _, file_list in os.walk(task_dir):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
config = utils.load_yaml_config(yaml_path, mode="simple")
if self._config_is_python_task(config):
# This is a python class config
tasks_and_groups[config["task"]] = {
"type": "python_task",
"yaml_path": yaml_path,
}
elif self._config_is_group(config):
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": yaml_path,
}
# # Registered the level 1 tasks from a group config
# for config in config["task"]:
# if isinstance(config, dict) and self._config_is_task(config):
# task = config["task"]
# tasks_and_groups[task] = {
# "type": "task",
# "yaml_path": yaml_path,
# }
elif self._config_is_task(config):
# This is a task config
task = config["task"]
tasks_and_groups[task] = {
"type": "task",
"yaml_path": yaml_path,
}
def include_task_folder(task_dir: str, register_task: bool = True) -> None: if "group" in config:
""" groups = config["group"]
Calling this function if isinstance(config["group"], str):
""" groups = [groups]
# Track whether any tasks failed during loading for group in groups:
import_fail = False if group not in tasks_and_groups:
for root, subdirs, file_list in os.walk(task_dir): tasks_and_groups[group] = {
# if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0): "type": "group",
for f in file_list: "task": [task],
if f.endswith(".yaml"): "yaml_path": -1,
yaml_path = os.path.join(root, f) }
try: else:
config = utils.load_yaml_config(yaml_path) tasks_and_groups[group]["task"].append(task)
else:
if "task" not in config: self.logger.debug(f"File {f} in {root} could not be loaded")
continue
all_configs = check_prompt_config(
config, yaml_path=os.path.dirname(yaml_path)
)
for config in all_configs:
if register_task:
if isinstance(config["task"], str):
register_configurable_task(config)
else:
if isinstance(config["task"], list):
register_configurable_group(config, yaml_path)
# Log this silently and show it only when
# the user defines the appropriate verbosity.
except (ImportError, ModuleNotFoundError) as e:
import_fail = True
eval_logger.debug(
f"{yaml_path}: {e}. Config will not be added to registry."
)
except Exception as error:
import traceback
eval_logger.warning(
"Unexpected error loading config in\n"
f" {yaml_path}\n"
" Config will not be added to registry\n"
f" Error: {error}\n"
f" Traceback: {traceback.format_exc()}"
)
if import_fail: return tasks_and_groups
eval_logger.warning(
"Some tasks could not be loaded due to missing dependencies."
" Run with `--verbosity DEBUG` for full details."
)
return 0
def include_path(task_dir): def include_path(task_dir):
include_task_folder(task_dir) logger = utils.eval_logger
# Register Benchmarks after all tasks have been added logger.setLevel(getattr(logging, "INFO"))
include_task_folder(task_dir, register_task=False) logger.info(
"To still use tasks loaded from args.include_path,"
"see an example of the new TaskManager API in https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
)
return 0 return 0
def initialize_tasks(verbosity="INFO"): def initialize_tasks(verbosity="INFO"):
eval_logger.setLevel(getattr(logging, f"{verbosity}")) logger = utils.eval_logger
logger.setLevel(getattr(logging, f"{verbosity}"))
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" logger.info(
include_path(task_dir) "lm_eval.tasks.initialize_tasks() is deprecated and no longer necessary. "
"It will be removed in v0.4.2 release. "
"TaskManager will instead be used."
)
return 0
def get_task(task_name, config): def get_task_name_from_config(task_config: Dict[str, str]) -> str:
try: if "task" in task_config:
return TASK_REGISTRY[task_name](config=config) return task_config["task"]
except KeyError: if "dataset_name" in task_config:
eval_logger.info("Available tasks:") return "{dataset_path}_{dataset_name}".format(**task_config)
eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY)) else:
raise KeyError(f"Missing task {task_name}") return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object): def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items(): if hasattr(task_object, "config"):
if class_ is task_object: return task_object._config["task"]
return name
# TODO: scrap this # TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
...@@ -235,53 +396,44 @@ def get_task_name_from_object(task_object): ...@@ -235,53 +396,44 @@ def get_task_name_from_object(task_object):
) )
# TODO: pass num_fewshot and other cmdline overrides in a better way def get_task_dict(
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs): task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None
config = {**kwargs} ):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
Task names (strings), task config dicts, or prepared Task objects to be loaded.
:param task_manager: TaskManager = None
A TaskManager object that stores indexed tasks. If not set,
task_manager will load one. This should be set by the user
if there are additional paths that want to be included
via `include_path`
task_name_from_registry_dict = {} :return
Dictionary of task objects
"""
task_name_from_string_dict = {}
task_name_from_config_dict = {} task_name_from_config_dict = {}
task_name_from_object_dict = {} task_name_from_object_dict = {}
if not isinstance(task_name_list, list): if isinstance(task_name_list, str):
task_name_list = [task_name_list] task_name_list = [task_name_list]
for task_element in task_name_list: string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
if isinstance(task_element, str): others_task_name_list = [task for task in task_name_list if not isinstance(task, str)]
if task_element in GROUP_REGISTRY: if len(string_task_name_list) > 0:
group_name = task_element if task_manager is None:
for task_name in GROUP_REGISTRY[task_element]: task_manager = TaskManager()
if task_name not in task_name_from_registry_dict:
task_obj = get_task_dict(task_name)
if task_name in task_obj.keys():
task_dict = {
task_name: (group_name, task_obj[task_name]),
}
else:
task_dict = {
task_name: (group_name, None),
**task_obj,
}
task_name_from_registry_dict = {
**task_name_from_registry_dict,
**task_dict,
}
else:
task_name = task_element
if task_name not in task_name_from_registry_dict:
task_name_from_registry_dict = {
**task_name_from_registry_dict,
task_name: get_task(task_name=task_element, config=config),
}
elif isinstance(task_element, dict): task_name_from_string_dict = task_manager.load_task_or_group(
task_element.update(config) string_task_name_list
)
for task_element in others_task_name_list:
if isinstance(task_element, dict):
task_name_from_config_dict = { task_name_from_config_dict = {
**task_name_from_config_dict, **task_name_from_config_dict,
get_task_name_from_config(task_element): ConfigurableTask( **task_manager.load_config(config=task_element),
config=task_element
),
} }
elif isinstance(task_element, Task): elif isinstance(task_element, Task):
...@@ -290,11 +442,12 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs): ...@@ -290,11 +442,12 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
get_task_name_from_object(task_element): task_element, get_task_name_from_object(task_element): task_element,
} }
assert set(task_name_from_registry_dict.keys()).isdisjoint( assert set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys()) set(task_name_from_object_dict.keys())
) )
return { return {
**task_name_from_registry_dict, **task_name_from_string_dict,
**task_name_from_config_dict, **task_name_from_config_dict,
**task_name_from_object_dict, **task_name_from_object_dict,
} }
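For orientation, a minimal sketch of the new `TaskManager` API defined above (the extra task directory is a hypothetical path):
```
from lm_eval.tasks import TaskManager, get_task_dict

# Index the built-in tasks under lm_eval/tasks/ plus an optional user directory.
task_manager = TaskManager(verbosity="INFO", include_path="/path/to/custom/tasks")

# Resolve task and group names into instantiated task objects.
task_dict = get_task_dict(["ammlu"], task_manager=task_manager)
```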
# ArabicMMLU
### Paper
ArabicMMLU: Measuring massive multitask language understanding in Arabic
This dataset has been translated from the original MMLU with the help of GPT-4.
The original data come from [MMLU](https://arxiv.org/pdf/2009.03300v3.pdf).
The translation was carried out in collaboration with the [AceGPT](https://arxiv.org/abs/2309.12053) researchers.
ArabicMMLU is a comprehensive evaluation benchmark specifically designed to evaluate the knowledge and reasoning abilities of LLMs within the context of the Arabic language and culture.
ArabicMMLU covers a wide range of subjects, comprising 57 topics that span from elementary to advanced professional levels.
Homepage: [AceGPT Homepage](https://github.com/FreedomIntelligence/AceGPT/tree/main/eval/benchmark_eval/benchmarks/MMLUArabic)
### Citation
### Groups and Tasks
#### Groups
- `ammlu`: All 57 subjects of the ArabicMMLU dataset, evaluated following the methodology in MMLU's original implementation.
#### Tasks
The following tasks evaluate subjects in the ArabicMMLU dataset using loglikelihood-based multiple-choice scoring:
- `ammlu_{subject_english}`
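These tasks can be run like any other harness task; below is a minimal sketch using the Python API (the pretrained checkpoint is a placeholder):
```
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=<your-model>",  # placeholder checkpoint
    tasks=["ammlu_anatomy"],
    num_fewshot=0,
    batch_size=8,
)
print(results["results"])
```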
### Checklist
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation?
* [x] Yes, original implementation contributed by author of the benchmark
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
group: ammlu
dataset_path: Hennara/ammlu
test_split: test
fewshot_split: dev
fewshot_config:
sampler: first_n
output_type: multiple_choice
doc_to_text: "{{Question.strip()}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nالجواب:"
doc_to_choice: ["A", "B", "C", "D"]
doc_to_target: "{{['A', 'B', 'C', 'D'].index(Answer)}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
metadata:
version: 0.0
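As a quick sanity check, the Jinja templates above can be rendered against a hypothetical row to see how the answer letter maps to a choice index:
```
from jinja2 import Template

# Hypothetical ArabicMMLU-style row; field names follow the YAML above.
doc = {"Question": "سؤال تجريبي", "A": "1", "B": "2", "C": "3", "D": "4", "Answer": "B"}

doc_to_target = "{{['A', 'B', 'C', 'D'].index(Answer)}}"
print(Template(doc_to_target).render(**doc))  # -> 1
```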
"""
Take in a base YAML and generate the per-subject YAML configs that include it.
"""
import argparse
import os
import yaml
from tqdm import tqdm
SUBJECTS = {
"abstract_algebra": "ألعلوم وتقنية المعلومات و الرياضيات",
"anatomy": "ألعلوم وتقنية المعلومات و الرياضيات",
"astronomy": "ألعلوم وتقنية المعلومات و الرياضيات",
"business_ethics": "علوم أخرى",
"clinical_knowledge": "علوم أخرى",
"college_biology": "ألعلوم وتقنية المعلومات و الرياضيات",
"college_chemistry": "ألعلوم وتقنية المعلومات و الرياضيات",
"college_computer_science": "ألعلوم وتقنية المعلومات و الرياضيات",
"college_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
"college_medicine": "علوم أخرى",
"college_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
"computer_security": "ألعلوم وتقنية المعلومات و الرياضيات",
"conceptual_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
"econometrics": "العلوم الإجتماعية",
"electrical_engineering": "ألعلوم وتقنية المعلومات و الرياضيات",
"elementary_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
"formal_logic": "العلوم الانسانية",
"global_facts": "علوم أخرى",
"high_school_biology": "ألعلوم وتقنية المعلومات و الرياضيات",
"high_school_chemistry": "ألعلوم وتقنية المعلومات و الرياضيات",
"high_school_computer_science": "ألعلوم وتقنية المعلومات و الرياضيات",
"high_school_european_history": "العلوم الانسانية",
"high_school_geography": "العلوم الإجتماعية",
"high_school_government_and_politics": "العلوم الإجتماعية",
"high_school_macroeconomics": "العلوم الإجتماعية",
"high_school_mathematics": "ألعلوم وتقنية المعلومات و الرياضيات",
"high_school_microeconomics": "العلوم الإجتماعية",
"high_school_physics": "ألعلوم وتقنية المعلومات و الرياضيات",
"high_school_psychology": "العلوم الإجتماعية",
"high_school_statistics": "ألعلوم وتقنية المعلومات و الرياضيات",
"high_school_us_history": "العلوم الانسانية",
"high_school_world_history": "العلوم الانسانية",
"human_aging": "علوم أخرى",
"human_sexuality": "العلوم الإجتماعية",
"international_law": "العلوم الانسانية",
"jurisprudence": "العلوم الانسانية",
"logical_fallacies": "العلوم الانسانية",
"machine_learning": "ألعلوم وتقنية المعلومات و الرياضيات",
"management": "علوم أخرى",
"marketing": "علوم أخرى",
"medical_genetics": "علوم أخرى",
"miscellaneous": "علوم أخرى",
"moral_disputes": "العلوم الانسانية",
"moral_scenarios": "العلوم الانسانية",
"nutrition": "علوم أخرى",
"philosophy": "العلوم الانسانية",
"prehistory": "العلوم الانسانية",
"professional_accounting": "علوم أخرى",
"professional_law": "العلوم الانسانية",
"professional_medicine": "علوم أخرى",
"professional_psychology": "العلوم الإجتماعية",
"public_relations": "العلوم الإجتماعية",
"security_studies": "العلوم الإجتماعية",
"sociology": "العلوم الإجتماعية",
"us_foreign_policy": "العلوم الإجتماعية",
"virology": "علوم أخرى",
"world_religions": "العلوم الانسانية",
}
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_yaml_path", required=True)
parser.add_argument("--save_prefix_path", default="ammlu")
parser.add_argument("--cot_prompt_path", default=None)
parser.add_argument("--task_prefix", default="")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs.
base_yaml_name = os.path.split(args.base_yaml_path)[-1]
with open(args.base_yaml_path, encoding="utf-8") as f:
base_yaml = yaml.full_load(f)
if args.cot_prompt_path is not None:
import json
with open(args.cot_prompt_path, encoding="utf-8") as f:
cot_file = json.load(f)
for subject_eng, category in tqdm(SUBJECTS.items()):
if args.cot_prompt_path is not None:
description = cot_file[subject_eng]
else:
description = f"فم بعملية التقييم في مجال {category} \n\n"
yaml_dict = {
"include": base_yaml_name,
"task": f"ammlu_{args.task_prefix}_{subject_eng}"
if args.task_prefix != ""
else f"ammlu_{subject_eng}",
"dataset_name": subject_eng,
"description": description,
}
file_save_path = args.save_prefix_path + f"_{subject_eng}.yaml"
print(f"Saving yaml for subset {subject_eng} to {file_save_path}")
with open(file_save_path, "w", encoding="utf-8") as yaml_file:
yaml.dump(
yaml_dict,
yaml_file,
width=float("inf"),
allow_unicode=True,
default_style='"',
)
"dataset_name": "abstract_algebra"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_abstract_algebra"
"dataset_name": "anatomy"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_anatomy"
"dataset_name": "astronomy"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_astronomy"
"dataset_name": "business_ethics"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_business_ethics"
"dataset_name": "clinical_knowledge"
"description": "فم بعملية التقييم في مجال علوم أخرى \n\n"
"include": "_default_template_yaml"
"task": "ammlu_clinical_knowledge"
"dataset_name": "college_biology"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_biology"
"dataset_name": "college_chemistry"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_college_chemistry"