Unverified Commit 4e1ef749 authored by Lintang Sutawika, committed by GitHub

Update huggingface.py

parent ae79b121
@@ -24,7 +24,7 @@ from transformers.models.auto.modeling_auto import (
 from lm_eval import utils
 from lm_eval.api.instance import Instance
-from lm_eval.api.model import TemplateLM
+from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import (
     Collator,
@@ -64,7 +64,7 @@ def _get_accelerate_args(
 @register_model("hf-auto", "hf", "huggingface")
-class HFLM(TemplateLM):
+class HFLM(LM):
     """
     An abstracted Huggingface model class. Enables usage with both models of
     `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
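Editor's note: switching the base class from `TemplateLM` back to `LM` changes which abstract interface `HFLM` must fill in itself. A rough sketch of the contract implied by the methods restored later in this diff (a simplification, not the real `lm_eval.api.model.LM`):

```python
# Rough sketch, not the real lm_eval.api.model.LM: the three request
# handlers that this diff makes HFLM implement directly.
from abc import ABC, abstractmethod
from typing import List, Tuple


class MiniLM(ABC):
    @abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Score (context, continuation) pairs."""

    @abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Score whole strings with rolling windows."""

    @abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate text until stop sequences are hit."""
```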
@@ -78,8 +78,9 @@ class HFLM(TemplateLM):
     def __init__(
         self,
         pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
-        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
-        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
+        backend: Optional[
+            Literal["default", "causal", "seq2seq"]
+        ] = "default",  # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
         revision: Optional[str] = "main",
         subfolder: Optional[str] = None,
         tokenizer: Optional[
@@ -90,7 +91,6 @@ class HFLM(TemplateLM):
             ]
         ] = None,
         truncation: Optional[bool] = False,
-        logits_cache: bool = True,
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
@@ -98,7 +98,6 @@ class HFLM(TemplateLM):
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
-        add_bos_token: Optional[bool] = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
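For orientation, a minimal usage sketch of the constructor being edited here (argument values are illustrative; `backend="causal"` forces decoder-only handling per the comment above):

```python
# Minimal usage sketch for the constructor above; values are illustrative.
from lm_eval.models.huggingface import HFLM

lm = HFLM(
    pretrained="gpt2",  # checkpoint name or a preloaded PreTrainedModel
    backend="causal",   # force decoder-only handling instead of "default"
    device="cpu",
    batch_size=1,
)
```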
@@ -240,7 +239,7 @@ class HFLM(TemplateLM):
         )
         self.truncation = truncation
-        self.logits_cache = logits_cache

         self.vocab_size = self.tokenizer.vocab_size

         # select (or create) a pad token to use
         if self.tokenizer.pad_token:
@@ -250,7 +249,7 @@ class HFLM(TemplateLM):
         elif self.tokenizer.eos_token:
             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
         else:
-            if getattr(self.config, "model_type", None) == "qwen":
+            if self.config.model_type == "qwen":
                 # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
                 self.tokenizer.pad_token = "<|endoftext|>"
             elif (
@@ -266,14 +265,6 @@ class HFLM(TemplateLM):
             else:
                 self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

-        # TODO: override this for Gemma
-        self.add_bos_token = add_bos_token
-        if getattr(self.config, "model_type", None) == "gemma":
-            self.add_bos_token = True
-            eval_logger.info(
-                f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it."
-            )

         self._max_length = max_length

         self.batch_schedule = 1
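The hunks above implement a pad-token fallback chain (existing pad, then unk, then eos, then a model-specific workaround). A simplified standalone sketch of that chain, assuming the unk check sits in the elided lines above:

```python
# Simplified sketch of the pad-token fallback chain, assuming the elided
# lines check unk_token before eos_token as in the surrounding code.
from transformers import AutoTokenizer


def ensure_pad_token(tokenizer, model_type: str) -> None:
    if tokenizer.pad_token:
        return  # already usable
    if tokenizer.unk_token:
        tokenizer.pad_token_id = tokenizer.unk_token_id
    elif tokenizer.eos_token:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    elif model_type == "qwen":
        # Qwen's remote-code tokenizer refuses added special tokens
        tokenizer.pad_token = "<|endoftext|>"
    else:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})


tok = AutoTokenizer.from_pretrained("gpt2")
ensure_pad_token(tok, model_type="gpt2")
print(tok.pad_token)  # GPT-2 resolves to "<|endoftext|>"
```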
@@ -666,9 +657,8 @@ class HFLM(TemplateLM):
         """ """
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                add_special_tokens = False or self.add_bos_token
+                add_special_tokens = False
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-                # TODO: investigate best practices for enc-dec models + special tokens
                 add_special_tokens = True

         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
@@ -691,7 +681,7 @@ class HFLM(TemplateLM):
         self.tokenizer.padding_side = padding_side

         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            add_special_tokens = False or self.add_bos_token
+            add_special_tokens = False
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             add_special_tokens = True
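The `add_special_tokens` defaults differ because causal and seq2seq tokenizers disagree about special tokens; a quick illustration (GPT-2 and T5 used only as familiar examples):

```python
# Illustration of why add_special_tokens defaults differ by backend:
# GPT-2 (causal) inserts nothing by default, T5 (seq2seq) appends </s>.
from transformers import AutoTokenizer

gpt2 = AutoTokenizer.from_pretrained("gpt2")
t5 = AutoTokenizer.from_pretrained("t5-small")

print(gpt2.encode("hello", add_special_tokens=False))  # plain BPE ids
print(t5.encode("hello", add_special_tokens=True))     # ends with </s> (id 1)
```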
@@ -770,9 +760,7 @@ class HFLM(TemplateLM):
             **generation_kwargs,
         )

-    def _select_cont_toks(
-        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
-    ) -> torch.Tensor:
+    def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             assert (
                 contlen and inplen
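For causal models, `_select_cont_toks` slices the continuation's logits out of the packed context-plus-continuation sequence; a toy standalone version of that slice:

```python
# Toy version of the causal-model branch of _select_cont_toks: keep only
# the logit positions that score the continuation tokens.
import torch


def select_cont_toks(logits: torch.Tensor, contlen: int, inplen: int) -> torch.Tensor:
    # logits: [seq, vocab] for one packed (context + continuation) input
    return logits[inplen - contlen : inplen]


logits = torch.randn(8, 5)  # 8 positions, toy vocab of 5
print(select_cont_toks(logits, contlen=3, inplen=8).shape)  # torch.Size([3, 5])
```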
@@ -790,6 +778,39 @@ class HFLM(TemplateLM):

         return logits

+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+
+        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
+        context_enc = self.tok_encode(context, add_special_tokens=False)
+
+        # whole_enc = self.tok_encode(context + continuation)
+        # context_enc = self.tok_encode(context, add_special_tokens=False)
+
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+
+        return context_enc, continuation_enc
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        new_reqs = []
+        for context, continuation in [req.args for req in requests]:
+            if context == "":
+                # end of text as context
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
+                )
+            else:
+                context_enc, continuation_enc = self._encode_pair(context, continuation)
+
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
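`_encode_pair` shuffles trailing whitespace from the context into the continuation because BPE tokenizers fuse a leading space into the next word; a worked illustration of the mismatch it avoids (GPT-2 as a familiar example):

```python
# Why _encode_pair moves trailing whitespace into the continuation:
# encoding the two halves independently does not reproduce the ids of
# encoding the whole string, because " Paris" and "Paris" differ in BPE.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
context, continuation = "The capital of France is ", "Paris"

naive = tok.encode(context) + tok.encode(continuation)

n = len(context) - len(context.rstrip())  # trailing spaces: 1 here
ctx, cont = context[:-n], context[-n:] + continuation
whole = tok.encode(ctx + cont)
ctx_enc = tok.encode(ctx)
cont_enc = whole[len(ctx_enc):]

print(naive)               # splits off " " and uses the space-less "Paris"
print(ctx_enc + cont_enc)  # identical to tok.encode(context + continuation)
```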
@@ -830,7 +851,7 @@ class HFLM(TemplateLM):
                 rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

             string_nll = self._loglikelihood_tokens(
-                requests=rolling_token_windows,
+                rolling_token_windows,
                 disable_tqdm=True,
                 override_bs=adaptive_batch_size,
             )
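`loglikelihood_rolling` scores an entire string by sliding a fixed-size window over its tokens; a self-contained sketch of the windowing idea (a simplification, not the exact `utils` helper this file calls):

```python
# Sketch of rolling-window construction for scoring long strings: each
# window predicts a disjoint span of tokens given the immediately
# preceding token(s), with a prefix token seeding the first window.
from typing import Iterator, List, Tuple


def rolling_windows(
    tokens: List[int], prefix_token: int, max_len: int
) -> Iterator[Tuple[List[int], List[int]]]:
    yield [prefix_token] + tokens[: max_len - 1], tokens[:max_len]
    pos = max_len
    while pos < len(tokens):
        pred = tokens[pos : pos + max_len]
        yield tokens[pos - 1 : pos - 1 + len(pred)], pred
        pos += len(pred)


# toy run: 10 tokens, window of 4 -> spans of 4, 4, and 2 tokens
for ctx, pred in rolling_windows(list(range(10)), prefix_token=-1, max_len=4):
    print(ctx, "->", pred)
```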
@@ -872,7 +893,7 @@ class HFLM(TemplateLM):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []

-        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _collate(x):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
@@ -881,26 +902,10 @@ class HFLM(TemplateLM):
             # automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
-            toks = req[1] + req[2]
+            toks = x[1] + x[2]
             return -len(toks), tuple(toks)

-        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
-            """Defines the key to group and lookup one-token continuations"""
-            # Use with group_by="contexts" (optional)"
-            # allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
-            # speeds up some multiple-choice tasks proportionally to the number of choices.
-            # groups requests by context+continuation[:-1] and infer on one request/group.
-            return req[-2] + req[-1][:-1]
-
-        re_ord = Collator(
-            requests,
-            sort_fn=_collate,
-            group_by="contexts"
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
-            and self.logits_cache
-            else None,
-            group_fn=_lookup_one_token_cont,
-        )
+        re_ord = Collator(requests, sort_fn=_collate)

         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
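The `_collate` key sorts requests by descending total token count so that any OOM surfaces on the very first batch; a tiny demonstration of the key on dummy requests:

```python
# Demonstrates the descending-length sort key used by _collate above.
reqs = [
    (("ctx a", "cont a"), [1, 2], [3]),
    (("ctx b", "cont b"), [1, 2, 3, 4], [5, 6]),
    (("ctx c", "cont c"), [1], [2]),
]


def collate_key(x):
    toks = x[1] + x[2]
    return -len(toks), tuple(toks)


for req in sorted(reqs, key=collate_key):
    print(len(req[1] + req[2]), req[0][0])  # lengths 6, 3, 2 in that order
```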
@@ -921,11 +926,7 @@ class HFLM(TemplateLM):
         )

         chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
-        pbar = tqdm(
-            total=len(requests),
-            disable=(disable_tqdm or (self.rank != 0)),
-            desc="Running loglikelihood requests",
-        )
+        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
         for chunk in chunks:
             inps = []
             cont_toks_list = []
@@ -1025,7 +1026,7 @@ class HFLM(TemplateLM):
                 self._model_call(batched_inps, **call_kwargs), dim=-1
             )  # [batch, padding_length (inp or cont), vocab]

-            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
+            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                 chunk, multi_logits, inplens, cont_toks_list
             ):
                 # Slice to original seq length
@@ -1044,36 +1045,24 @@ class HFLM(TemplateLM):
                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
+                cont_toks = torch.tensor(
+                    cont_toks, dtype=torch.long, device=self.device
+                ).unsqueeze(0)  # [1, seq]
+                max_equal = (greedy_tokens == cont_toks).all()
+
+                # Obtain log-probs at the corresponding continuation token indices
+                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
+                    -1
+                )  # [1, seq]

-                # check for one-token continuation cache hits.
-                # noop in case group_by != "contexts" or no cache hit and returns the
-                # original args. Otherwise, expands the logits batch dimension and yields each
-                # batch along with matching continuation tokens and prompt strings.
-                # logits -> [1, seq, vocab]
-                for request_str, cont_toks, logits in re_ord.get_cache(
-                    req_str=request_str,
-                    cxt_toks=ctx_tokens,
-                    cont_toks=cont_toks,
-                    logits=logits,
-                ):
-                    cont_toks = torch.tensor(
-                        cont_toks, dtype=torch.long, device=self.device
-                    ).unsqueeze(0)  # [1, seq]
-                    max_equal = (greedy_tokens == cont_toks).all()
-
-                    # Obtain log-probs at the corresponding continuation token indices
-                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-                        -1
-                    )  # [1, seq]
-
-                    # Answer: (log prob, is-exact-match)
-                    answer = (float(logits.sum()), bool(max_equal))
-
-                    res.append(answer)
-
-                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
-                    pbar.update(1)
+                # Answer: (log prob, is-exact-match)
+                answer = (float(logits.sum()), bool(max_equal))
+
+                res.append(answer)
+
+                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+                pbar.update(1)

         pbar.close()
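The scoring loop above boils down to a log-softmax followed by a gather along the vocab axis plus a greedy-match check; a minimal self-contained version of that arithmetic:

```python
# Minimal sketch of continuation scoring as done in the loop above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = F.log_softmax(torch.randn(1, 3, 10), dim=-1)  # [1, seq, vocab]
cont_toks = torch.tensor([[4, 1, 7]])                  # [1, seq] gold ids

greedy_tokens = logits.argmax(dim=-1)                  # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()         # exact greedy match?
token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)

answer = (float(token_logprobs.sum()), bool(max_equal))
print(answer)  # (summed log-prob, is-exact-match)
```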
@@ -1082,7 +1071,7 @@ class HFLM(TemplateLM):
     def generate_until(self, requests: List[Instance]) -> List[str]:
         res = []

-        def _collate(req: Tuple[str, dict]):
+        def _collate(x):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
@@ -1090,14 +1079,10 @@ class HFLM(TemplateLM):
             # padded context length. this is useful to simplify the batching logic and more importantly to make
             # automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(req[0])
-            return -len(toks), req[0]
+            toks = self.tok_encode(x[0])
+            return -len(toks), x[0]

-        pbar = tqdm(
-            total=len(requests),
-            disable=(self.rank != 0),
-            desc="Running generate_until requests",
-        )
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -1122,13 +1107,7 @@ class HFLM(TemplateLM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
-        re_ords = Collator(
-            [reg.args for reg in requests],
-            sort_fn=_collate,
-            group_by="gen_kwargs",
-            group_fn=lambda x: x[1],
-        )
+        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
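Grouping by generation kwargs keeps incompatible sampling settings out of a shared batch; a rough standalone sketch of that grouping (the real `Collator` also sorts and restores the original order):

```python
# Rough sketch of grouping requests by generation kwargs so that, e.g.,
# greedy and temperature-0.8 requests never share a batch.
from collections import defaultdict

requests = [
    ("prompt a", {"do_sample": False}),
    ("prompt b", {"temperature": 0.8}),
    ("prompt c", {"do_sample": False}),
]

groups = defaultdict(list)
for context, gen_kwargs in requests:
    # dicts are unhashable, so freeze them into a sorted item tuple
    groups[tuple(sorted(gen_kwargs.items()))].append(context)

for key, contexts in groups.items():
    print(dict(key), "->", contexts)
```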
@@ -1151,12 +1130,8 @@ class HFLM(TemplateLM):
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
+            until = [self.tok_decode(self.eot_token_id)]

             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
......