Unverified Commit 9822b06e authored by Lintang Sutawika, committed by GitHub

Merge branch 'main' into weight_by_size

parents 51f27158 b177c82c
 import copy
 import os
+from datetime import timedelta
 from pathlib import Path
 from typing import List, Literal, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 import transformers
-from accelerate import Accelerator, DistributedType, find_executable_batch_size
+from accelerate import (
+    Accelerator,
+    DistributedType,
+    InitProcessGroupKwargs,
+    find_executable_batch_size,
+)
 from packaging import version
 from peft import PeftModel
 from peft import __version__ as PEFT_VERSION

@@ -18,9 +24,15 @@ from transformers.models.auto.modeling_auto import (
 from lm_eval import utils
 from lm_eval.api.instance import Instance
-from lm_eval.api.model import LM
+from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.utils import Collator, stop_sequences_criteria
+from lm_eval.models.utils import (
+    Collator,
+    clear_torch_cache,
+    get_dtype,
+    pad_and_concat,
+    stop_sequences_criteria,
+)

 eval_logger = utils.eval_logger
@@ -52,7 +64,7 @@ def _get_accelerate_args(
 @register_model("hf-auto", "hf", "huggingface")
-class HFLM(LM):
+class HFLM(TemplateLM):
     """
     An abstracted Huggingface model class. Enables usage with both models of
     `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
@@ -66,9 +78,8 @@ class HFLM(LM):
     def __init__(
         self,
         pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
-        backend: Optional[
-            Literal["default", "causal", "seq2seq"]
-        ] = "default",  # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
+        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
+        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
         revision: Optional[str] = "main",
         subfolder: Optional[str] = None,
         tokenizer: Optional[
@@ -79,6 +90,7 @@ class HFLM(LM):
             ]
         ] = None,
         truncation: Optional[bool] = False,
+        logits_cache: bool = True,
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
@@ -86,6 +98,7 @@ class HFLM(LM):
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
+        add_bos_token: Optional[bool] = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
@@ -108,8 +121,8 @@ class HFLM(LM):
            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
+           gpus = 0

            if tokenizer:
                assert isinstance(
@@ -132,7 +145,8 @@ class HFLM(LM):
        assert isinstance(batch_size, (int, str))

        gpus = torch.cuda.device_count()
-       accelerator = Accelerator()
+       accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+       accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self.accelerator = accelerator
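For context on the hunk above: `InitProcessGroupKwargs` is Accelerate's way of passing a custom timeout down to `torch.distributed.init_process_group`, so that ranks waiting on a long-running evaluation step do not hit the default distributed timeout and abort the run. A minimal standalone sketch of the same pattern (not part of this diff):

from datetime import timedelta

from accelerate import Accelerator, InitProcessGroupKwargs

# Raise the distributed init/collective timeout well above the default;
# the concrete value (52 weeks) mirrors the change in the hunk above.
ipg_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[ipg_kwargs])

if accelerator.is_main_process:
    print(f"world size: {accelerator.num_processes}")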
@@ -226,7 +240,7 @@ class HFLM(LM):
            )

        self.truncation = truncation
+       self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        if self.tokenizer.pad_token:
@@ -236,7 +250,7 @@ class HFLM(LM):
        elif self.tokenizer.eos_token:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        else:
-           if self.config.model_type == "qwen":
+           if getattr(self.config, "model_type", None) == "qwen":
                # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
                self.tokenizer.pad_token = "<|endoftext|>"
            elif (
@@ -252,6 +266,14 @@ class HFLM(LM):
            else:
                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

+       # TODO: override this for Gemma
+       self.add_bos_token = add_bos_token
+       if getattr(self.config, "model_type", None) == "gemma":
+           self.add_bos_token = True
+           eval_logger.info(
+               f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it."
+           )
+
        self._max_length = max_length

        self.batch_schedule = 1
@@ -372,7 +394,7 @@ class HFLM(LM):
    def _get_backend(
        self,
-       config: transformers.AutoConfig,
+       config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
@@ -496,13 +518,13 @@ class HFLM(LM):
        if transformers.__version__ >= "4.30.0":
            if model_kwargs.get("load_in_4bit", None):
                if model_kwargs.get("bnb_4bit_compute_dtype", None):
-                   model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
+                   model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
                        model_kwargs["bnb_4bit_compute_dtype"]
                    )
        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained,
            revision=revision,
-           torch_dtype=utils.get_dtype(dtype),
+           torch_dtype=get_dtype(dtype),
            trust_remote_code=trust_remote_code,
            **model_kwargs,
        )
@@ -617,7 +639,13 @@ class HFLM(LM):
                return batch_size

-       batch_size = forward_batch()
+       try:
+           batch_size = forward_batch()
+       except RuntimeError as e:
+           if "No executable batch size found" in str(e):
+               batch_size = 1
+           else:
+               raise

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
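The try/except added above guards the batch-size probe: Accelerate's `find_executable_batch_size` retries a callable with a halved batch size after CUDA out-of-memory errors and raises a RuntimeError containing "No executable batch size found" once it reaches zero; the new branch maps that failure to a batch size of 1 instead of aborting. A self-contained sketch of the pattern, with a toy allocation standing in for the real forward pass:

import torch
from accelerate import find_executable_batch_size


def detect_batch_size(max_seq_len: int = 2048, starting_batch_size: int = 64) -> int:
    """Illustrative stand-alone version of the fallback used in the hunk above."""

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def forward_batch(batch_size):
        # Stand-in for a real forward pass: allocate an activation-sized buffer.
        # accelerate halves batch_size and retries whenever this raises CUDA OOM.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _ = torch.zeros((batch_size, max_seq_len), device=device)
        return batch_size

    try:
        return forward_batch()
    except RuntimeError as e:
        # Raised by accelerate once the batch size has been halved down to zero.
        if "No executable batch size found" in str(e):
            return 1
        raise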
@@ -626,10 +654,10 @@ class HFLM(LM):
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
-           utils.clear_torch_cache()
+           clear_torch_cache()
            return batch_size

-       utils.clear_torch_cache()
+       clear_torch_cache()
        return batch_size

    def tok_encode(
@@ -638,8 +666,9 @@ class HFLM(LM):
        """ """
        if add_special_tokens is None:
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-               add_special_tokens = False
+               add_special_tokens = False or self.add_bos_token
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+               # TODO: investigate best practices for enc-dec models + special tokens
                add_special_tokens = True

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
@@ -662,7 +691,7 @@ class HFLM(LM):
        self.tokenizer.padding_side = padding_side

        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-           add_special_tokens = False
+           add_special_tokens = False or self.add_bos_token
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True
@@ -721,6 +750,11 @@ class HFLM(LM):
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)
+
+       # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
+       if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+           generation_kwargs["do_sample"] = do_sample = False
+
        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")
        # build stopping criteria
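As a quick illustration of what the added normalization does (a sketch, not part of the diff): with HF `generate`, temperature only matters when sampling, so a temperature of 0.0 is treated as plain greedy decoding by forcing `do_sample=False` and then dropping the temperature key, which avoids the "temperature is set but do_sample=False" style warnings.

def normalize_gen_kwargs(generation_kwargs: dict) -> dict:
    """Sketch of the normalization above: temperature == 0.0 means greedy decoding."""
    kwargs = dict(generation_kwargs)
    kwargs["temperature"] = kwargs.get("temperature", 0.0)
    do_sample = kwargs.get("do_sample", None)
    if kwargs["temperature"] == 0.0 and do_sample is None:
        kwargs["do_sample"] = do_sample = False
    if do_sample is False and kwargs.get("temperature") == 0.0:
        # HF warns if temperature is set while sampling is disabled, so drop it.
        kwargs.pop("temperature")
    return kwargs


assert normalize_gen_kwargs({}) == {"do_sample": False}
assert normalize_gen_kwargs({"temperature": 0.7, "do_sample": True}) == {
    "temperature": 0.7,
    "do_sample": True,
}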
@@ -736,7 +770,9 @@ class HFLM(LM):
            **generation_kwargs,
        )

-   def _select_cont_toks(self, logits, contlen=None, inplen=None):
+   def _select_cont_toks(
+       self, logits: torch.Tensor, contlen: int = None, inplen: int = None
+   ) -> torch.Tensor:
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
@@ -754,39 +790,6 @@ class HFLM(LM):

        return logits

-   def _encode_pair(
-       self, context: str, continuation: str
-   ) -> Tuple[List[int], List[int]]:
-       n_spaces = len(context) - len(context.rstrip())
-       if n_spaces > 0:
-           continuation = context[-n_spaces:] + continuation
-           context = context[:-n_spaces]
-
-       whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-       context_enc = self.tok_encode(context, add_special_tokens=False)
-
-       # whole_enc = self.tok_encode(context + continuation)
-       # context_enc = self.tok_encode(context, add_special_tokens=False)
-
-       context_enc_len = len(context_enc)
-       continuation_enc = whole_enc[context_enc_len:]
-
-       return context_enc, continuation_enc
-
-   def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-       new_reqs = []
-       for context, continuation in [req.args for req in requests]:
-           if context == "":
-               # end of text as context
-               context_enc, continuation_enc = (
-                   [self.eot_token_id],
-                   self.tok_encode(continuation),
-               )
-           else:
-               context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-           new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-       return self._loglikelihood_tokens(new_reqs)
-
    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        loglikelihoods = []
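The removed `_encode_pair` and `loglikelihood` helpers are not lost behaviour: with the base class switched from `LM` to `TemplateLM`, they are presumably inherited from there. The trailing-whitespace handling they implement is worth spelling out: any spaces at the end of the context are moved onto the continuation before tokenization, so the context/continuation boundary falls on a stable token split. A standalone sketch of that trick, using a toy whitespace tokenizer instead of `self.tok_encode`:

from typing import List, Tuple


def toy_tokenize(text: str) -> List[str]:
    # Stand-in tokenizer for illustration; the real code uses the model tokenizer.
    return text.split()


def encode_pair(context: str, continuation: str) -> Tuple[List[str], List[str]]:
    # Move trailing context whitespace onto the continuation, as in the removed helper,
    # so " Paris" is tokenized together rather than splitting on a dangling space.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]

    whole = toy_tokenize(context + continuation)
    context_toks = toy_tokenize(context)
    return context_toks, whole[len(context_toks):]


print(encode_pair("The capital of France is ", "Paris"))
# (['The', 'capital', 'of', 'France', 'is'], ['Paris'])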
@@ -827,7 +830,7 @@ class HFLM(LM):
                rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
-               rolling_token_windows,
+               requests=rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )
@@ -869,7 +872,7 @@ class HFLM(LM):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

-       def _collate(x):
+       def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
@@ -878,10 +881,26 @@ class HFLM(LM):
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-           toks = x[1] + x[2]
+           toks = req[1] + req[2]
            return -len(toks), tuple(toks)

-       re_ord = Collator(requests, sort_fn=_collate)
+       def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+           """Defines the key to group and lookup one-token continuations"""
+           # Use with group_by="contexts" (optional)"
+           # allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
+           # speeds up some multiple-choice tasks proportionally to the number of choices.
+           # groups requests by context+continuation[:-1] and infer on one request/group.
+           return req[-2] + req[-1][:-1]
+
+       re_ord = Collator(
+           requests,
+           sort_fn=_collate,
+           group_by="contexts"
+           if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+           and self.logits_cache
+           else None,
+           group_fn=_lookup_one_token_cont,
+       )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
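The grouping key added above enables a simple logits-reuse optimization: multiple-choice requests that share a context and differ only in a single final continuation token produce identical model inputs, so the forward pass can run once per group and the last-position logits can score every candidate answer. A standalone sketch of the idea using plain dictionaries (not the actual `Collator` API):

from collections import defaultdict
from typing import Dict, List, Tuple

# Each request: (context tokens, continuation tokens); multiple-choice answers
# are often a single token, e.g. " A" / " B" / " C" / " D".
Request = Tuple[Tuple[int, ...], Tuple[int, ...]]

requests: List[Request] = [
    ((11, 22, 33), (101,)),
    ((11, 22, 33), (102,)),
    ((11, 22, 33), (103,)),
    ((44, 55), (101,)),
]

# Same key as _lookup_one_token_cont above: context + all-but-last continuation token.
groups: Dict[Tuple[int, ...], List[Request]] = defaultdict(list)
for ctx, cont in requests:
    groups[ctx + cont[:-1]].append((ctx, cont))

# One forward pass per group instead of one per request: the cached logits
# at the final position score every candidate continuation token in the group.
for key, members in groups.items():
    print(f"run model once on {key!r}, reuse logits for {len(members)} request(s)")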
@@ -902,7 +921,11 @@ class HFLM(LM):
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
-       pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+       pbar = tqdm(
+           total=len(requests),
+           disable=(disable_tqdm or (self.rank != 0)),
+           desc="Running loglikelihood requests",
+       )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
@@ -979,18 +1002,18 @@ class HFLM(LM):
            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-               batched_inps = utils.pad_and_concat(
+               batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
-               batched_inps = utils.pad_and_concat(
+               batched_inps = pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
-               batched_conts = utils.pad_and_concat(
+               batched_conts = pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
-               batched_encoder_mask = utils.pad_and_concat(
+               batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
@@ -1002,7 +1025,7 @@ class HFLM(LM):
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

-           for (cache_key, _, _), logits, inplen, cont_toks in zip(
+           for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
@@ -1021,24 +1044,36 @@ class HFLM(LM):
                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
-               cont_toks = torch.tensor(
-                   cont_toks, dtype=torch.long, device=self.device
-               ).unsqueeze(0)  # [1, seq]
-               max_equal = (greedy_tokens == cont_toks).all()
-
-               # Obtain log-probs at the corresponding continuation token indices
-               # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-               logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-                   -1
-               )  # [1, seq]
-
-               # Answer: (log prob, is-exact-match)
-               answer = (float(logits.sum()), bool(max_equal))
-
-               res.append(answer)
-
-               self.cache_hook.add_partial("loglikelihood", cache_key, answer)
-               pbar.update(1)
+
+               # check for one-token continuation cache hits.
+               # noop in case group_by != "contexts" or no cache hit and returns the
+               # original args. Otherwise, expands the logits batch dimension and yields each
+               # batch along with matching continuation tokens and prompt strings.
+               # logits -> [1, seq, vocab]
+               for request_str, cont_toks, logits in re_ord.get_cache(
+                   req_str=request_str,
+                   cxt_toks=ctx_tokens,
+                   cont_toks=cont_toks,
+                   logits=logits,
+               ):
+                   cont_toks = torch.tensor(
+                       cont_toks, dtype=torch.long, device=self.device
+                   ).unsqueeze(0)  # [1, seq]
+                   max_equal = (greedy_tokens == cont_toks).all()
+
+                   # Obtain log-probs at the corresponding continuation token indices
+                   # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                   logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
+                       -1
+                   )  # [1, seq]
+
+                   # Answer: (log prob, is-exact-match)
+                   answer = (float(logits.sum()), bool(max_equal))
+
+                   res.append(answer)
+
+                   self.cache_hook.add_partial("loglikelihood", request_str, answer)
+                   pbar.update(1)

        pbar.close()
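To make the scoring inside that loop concrete: `logits` holds log-softmaxed scores of shape [1, seq, vocab], and `torch.gather` picks out, at each continuation position, the log-probability of the token that actually appears there; summing gives the continuation log-likelihood, while comparing against the per-position argmax gives the greedy exact-match flag. A self-contained numeric sketch (toy tensors, not harness code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, seq = 10, 3

# Pretend model output for the continuation positions: [1, seq, vocab] log-probs.
logits = F.log_softmax(torch.randn(1, seq, vocab), dim=-1)
cont_toks = torch.tensor([[4, 1, 7]])  # [1, seq] target continuation tokens

greedy_tokens = logits.argmax(dim=-1)           # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()  # does greedy decoding reproduce the continuation?

# Log-prob of each target token at its position, then summed over the sequence.
token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
answer = (float(token_logprobs.sum()), bool(max_equal))
print(answer)  # (summed log-prob, exact-match flag)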
@@ -1047,7 +1082,7 @@ class HFLM(LM):
    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

-       def _collate(x):
+       def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
@@ -1055,10 +1090,15 @@ class HFLM(LM):
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-           toks = self.tok_encode(x[0])
-           return -len(toks), x[0]
+           toks = self.tok_encode(req[0])
+           return -len(toks), req[0]

-       pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+       pbar = tqdm(
+           total=len(requests),
+           disable=(self.rank != 0),
+           desc="Running generate_until requests",
+       )
+       adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
@@ -1082,7 +1122,13 @@ class HFLM(LM):
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
-       re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
+       # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
+       re_ords = Collator(
+           [reg.args for reg in requests],
+           sort_fn=_collate,
+           group_by="gen_kwargs",
+           group_fn=lambda x: x[1],
+       )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
@@ -1103,7 +1149,7 @@ class HFLM(LM):
                    )
                else:
                    raise ValueError(
-                       f"Expected `kwargs` to be of type `dict` but got {kwargs}"
+                       f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                    )
            if not until:
                until = [self.tok_decode(self.eot_token_id)]
...
@@ -2,7 +2,7 @@ from typing import Optional, Union

 import torch

-from lm_eval import utils
+import lm_eval.models.utils
 from lm_eval.api.registry import register_model
 from lm_eval.models.huggingface import HFLM

@@ -56,9 +56,9 @@ class MambaLMWrapper(HFLM):
        super().__init__(
            pretrained=pretrained,
            # set appropriate defaults for tokenizer, max length, etc
-           backend=kwargs.get("backend", "causal"),
-           tokenizer=kwargs.get("tokenizer", "EleutherAI/gpt-neox-20b"),
-           max_length=kwargs.get("max_length", 2048),
+           backend=kwargs.pop("backend", "causal"),
+           tokenizer=kwargs.pop("tokenizer", "EleutherAI/gpt-neox-20b"),
+           max_length=kwargs.pop("max_length", 2048),
            **kwargs,
        )
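The switch from `kwargs.get` to `kwargs.pop` here (and in the OpenVINO wrapper further down) matters because the same dict is forwarded via `**kwargs` in the same call: `get` leaves the key in place, so a caller that passes e.g. `backend=...` would trigger a "got multiple values for keyword argument" TypeError, whereas `pop` consumes it first. A minimal sketch of the failure mode, with hypothetical helper names:

def child_init(backend: str = "causal", **kwargs) -> dict:
    # Stand-in for the superclass __init__; just echoes what it received.
    return {"backend": backend, **kwargs}


def wrapper_with_get(**kwargs):
    # Buggy pattern: 'backend' stays in kwargs and is passed twice.
    return child_init(backend=kwargs.get("backend", "causal"), **kwargs)


def wrapper_with_pop(**kwargs):
    # Fixed pattern: 'backend' is consumed before **kwargs is forwarded.
    return child_init(backend=kwargs.pop("backend", "causal"), **kwargs)


print(wrapper_with_pop(backend="causal", batch_size=8))
# {'backend': 'causal', 'batch_size': 8}
try:
    wrapper_with_get(backend="causal", batch_size=8)
except TypeError as e:
    print(e)  # ... got multiple values for keyword argument 'backend'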
@@ -97,7 +99,9 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba
        self._model = MambaLMHeadModel.from_pretrained(
            pretrained,
            device=self._device,
-           dtype=torch.float16 if dtype == "auto" else utils.get_dtype(dtype),
+           dtype=torch.float16
+           if dtype == "auto"
+           else lm_eval.models.utils.get_dtype(dtype),
        )

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
...
This diff is collapsed.
@@ -6,10 +6,12 @@ from typing import List, Literal, Optional, Tuple

 from tqdm import tqdm

+import lm_eval.models.utils
 from lm_eval import utils
-from lm_eval.api.model import LM
+from lm_eval.api.model import LM, TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.utils import eval_logger, retry_on_specific_exceptions
+from lm_eval.models.utils import retry_on_specific_exceptions
+from lm_eval.utils import eval_logger


 def get_result(response, ctxlen: int) -> Tuple[float, bool]:
@@ -73,7 +75,7 @@ def oa_completion(client, chat: bool = False, **kwargs):

 @register_model("openai-completions", "local-completions")
-class OpenaiCompletionsLM(LM):
+class OpenaiCompletionsLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048

     def __init__(
@@ -169,41 +171,12 @@ class OpenaiCompletionsLM(LM):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

-   def tok_encode(self, string: str) -> List[int]:
+   def tok_encode(self, string: str, **kwargs) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)
-   def _encode_pair(
-       self, context: str, continuation: str
-   ) -> Tuple[List[int], List[int]]:
-       n_spaces = len(context) - len(context.rstrip())
-       if n_spaces > 0:
-           continuation = context[-n_spaces:] + continuation
-           context = context[:-n_spaces]
-
-       whole_enc = self.tok_encode(context + continuation)
-       context_enc = self.tok_encode(context)
-
-       context_enc_len = len(context_enc)
-       continuation_enc = whole_enc[context_enc_len:]
-
-       return context_enc, continuation_enc
-
-   def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
-       new_reqs = []
-       for context, continuation in [req.args for req in requests]:
-           if context == "":
-               # end of text as context
-               context_enc, continuation_enc = (
-                   [self.eot_token_id],
-                   self.tok_encode(continuation),
-               )
-           else:
-               context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-           new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-       return self._loglikelihood_tokens(new_reqs)
-
    def _loglikelihood_tokens(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
@@ -219,7 +192,7 @@ class OpenaiCompletionsLM(LM):
        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
-           list(utils.chunks(re_ord.get_reordered(), self.batch_size)),
+           list(lm_eval.models.utils.chunks(re_ord.get_reordered(), self.batch_size)),
            disable=disable_tqdm,
        ):
            inps = []
@@ -288,14 +261,13 @@ class OpenaiCompletionsLM(LM):
            list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size))
        ):
            inps = []
-           self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
+           self._max_gen_toks = request_args.get("max_gen_toks", self.max_gen_toks)
            for context, _ in chunk:
                context_enc = self.tok_encode(context)
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                inps.append(inp)

-           until = request_args.pop("until", ["<|endoftext|>"])
-           request_args.pop("do_sample", None)
+           until = request_args.get("until", ["<|endoftext|>"])
            request_args["temperature"] = request_args.get("temperature", 0)

            response = oa_completion(
@@ -305,7 +277,11 @@ class OpenaiCompletionsLM(LM):
                max_tokens=self.max_gen_toks,
                stop=until,
                seed=self.seed,
-               **request_args,
+               **{
+                   k: v
+                   for k, v in request_args.items()
+                   if k not in ["do_sample", "max_gen_toks"]
+               },
            )
            for resp, (context, args_) in zip(response.choices, chunk):
                s = getattr(resp, "text")
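The rewritten call above stops mutating `request_args` in place (`get` instead of `pop`) and instead filters out keys the completions endpoint would reject: `do_sample` is not an OpenAI completions parameter, and `max_gen_toks` is an lm-eval-internal knob already applied via `max_tokens`. A small standalone illustration of the non-destructive filtering pattern:

def build_api_kwargs(request_args: dict, drop=("do_sample", "max_gen_toks")) -> dict:
    # Non-destructive filtering: request_args itself stays intact for later chunks.
    return {k: v for k, v in request_args.items() if k not in drop}


args = {"temperature": 0, "do_sample": False, "max_gen_toks": 64, "top_p": 0.9}
print(build_api_kwargs(args))  # {'temperature': 0, 'top_p': 0.9}
print(args)                    # the original dict is unchanged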
@@ -429,7 +405,7 @@ class OpenaiChatCompletionsLM(LM):
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
-       grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+       grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer(
@@ -441,7 +417,7 @@ class OpenaiChatCompletionsLM(LM):
            # n needs to be 1 because messages in
            # chat completion are not batch but
            # is regarded as a single conversation.
-           chunks = utils.chunks(re_ord.get_reordered(), n=1)
+           chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]
...
@@ -28,7 +28,7 @@ class OptimumLM(HFLM):

        super().__init__(
            device=self.openvino_device,
-           backend=kwargs.get("backend", "causal"),
+           backend=kwargs.pop("backend", "causal"),
            **kwargs,
        )
...
@@ -19,7 +19,7 @@ from tqdm import tqdm

 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
-from lm_eval.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import retry_on_specific_exceptions

 logger = logging.getLogger(__name__)
...
This diff is collapsed.
This diff is collapsed.
-import os
 import ast
+import os
 from typing import Dict

 from lm_eval import utils
 from lm_eval.utils import eval_logger

 # Prompt library.
 # Stores prompts in a dictionary indexed by 2 levels:
 # prompt category name, and prompt name.
...
This diff is collapsed.
This diff is collapsed.
group: ammlu
dataset_path: Hennara/ammlu
test_split: test
fewshot_split: dev
fewshot_config:
  sampler: first_n
output_type: multiple_choice
doc_to_text: "{{Question.strip()}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nالجواب:"
doc_to_choice: ["A", "B", "C", "D"]
doc_to_target: "{{['A', 'B', 'C', 'D'].index(Answer)}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0
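For readers new to these task configs: `doc_to_text` is a Jinja template rendered against each dataset row, and `doc_to_target` maps the gold letter in `Answer` to the index of the matching entry in `doc_to_choice`. A rough Python equivalent of that mapping, using a made-up example row rather than real dataset content:

doc = {"Question": "ما هو ناتج 2 + 2؟", "A": "3", "B": "4", "C": "5", "D": "6", "Answer": "B"}

choices = ["A", "B", "C", "D"]
prompt = (
    f"{doc['Question'].strip()}\n"
    + "\n".join(f"{letter}. {doc[letter]}" for letter in choices)
    + "\nالجواب:"
)
target_index = choices.index(doc["Answer"])  # doc_to_target -> 1, i.e. choice "B"
print(prompt)
print(target_index)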
This diff is collapsed.
"dataset_name": "abstract_algebra"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_abstract_algebra"
"dataset_name": "anatomy"
"description": "فم بعملية التقييم في مجال ألعلوم وتقنية المعلومات و الرياضيات \n\n"
"include": "_default_template_yaml"
"task": "ammlu_anatomy"
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.