Commit dced6e96 authored by haileyschoelkopf

add batched generation

parent e3960fa0
@@ -15,6 +15,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator
+from typing import List, Union


 @register_model("hf-auto", "hf", "huggingface")
@@ -99,6 +100,7 @@ class HFLM(LM):
         )
         self.vocab_size = self.tokenizer.vocab_size
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
         self._max_length = max_length
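
Note on the line added above: batched encoding requires a pad token, and many causal-LM tokenizers (GPT-2 among them) ship without one, so the commit reuses the EOS id. A minimal standalone sketch of the failure mode and the fix; the "gpt2" checkpoint here is an illustrative assumption, not taken from the commit:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 defines no pad token, so padded batch encoding raises a ValueError:
# tokenizer(["a", "a longer prompt"], padding="longest", return_tensors="pt")

# Reusing EOS as the pad token (as the commit does) makes padding work;
# padded positions are masked out via the attention mask anyway.
tokenizer.pad_token_id = tokenizer.eos_token_id
batch = tokenizer(["a", "a longer prompt"], padding="longest", return_tensors="pt")
print(batch["input_ids"].shape)  # (2, longest_prompt_len_in_batch)
```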
@@ -204,6 +206,33 @@ class HFLM(LM):
         return encoding

+    def tok_batch_encode(
+        self, strings: List[str], padding_side="left", left_truncate_len=None
+    ):
+        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
+        old_padding_side = self.tokenizer.padding_side
+        self.tokenizer.padding_side = padding_side
+
+        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            add_special_tokens = False
+        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            add_special_tokens = True
+
+        encoding = self.tokenizer(
+            strings,
+            padding="longest",
+            return_tensors="pt",
+            add_special_tokens=add_special_tokens,
+        )
+        if left_truncate_len:
+            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
+            encoding["attention_mask"] = encoding["attention_mask"][
+                :, -left_truncate_len:
+            ]
+        self.tokenizer.padding_side = old_padding_side
+
+        return encoding["input_ids"], encoding["attention_mask"]
+
     def tok_decode(self, tokens):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             return self.tokenizer.decode(tokens)
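
For reference, a standalone sketch of what the new `tok_batch_encode` does: left padding so every prompt ends at the last position (where generation continues), and left truncation so over-long prompts keep their most recent tokens. The tokenizer, prompts, and truncation length below are illustrative assumptions, not part of the commit:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"  # pad on the left so real tokens end at position -1

prompts = ["The capital of France is", "Hi"]
enc = tokenizer(prompts, padding="longest", return_tensors="pt", add_special_tokens=False)

# Left truncation keeps the *last* tokens of each prompt, matching the slice
# encoding["input_ids"][:, -left_truncate_len:] in the commit.
left_truncate_len = 4
input_ids = enc["input_ids"][:, -left_truncate_len:]
attention_mask = enc["attention_mask"][:, -left_truncate_len:]
print(input_ids.shape, attention_mask.shape)  # both (2, <= 4)
```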
@@ -495,13 +524,21 @@ class HFLM(LM):

         def _collate(x):
             toks = self.tok_encode(x[0])
-            return len(toks), x[0]
+            return -len(toks), x[0]

         re_ord = utils.Reorderer([req.args for req in requests], _collate)

-        for context, gen_kwargs in tqdm(
-            re_ord.get_reordered(), disable=(self.rank != 0)
+        for chunk in utils.chunks(
+            tqdm(
+                re_ord.get_reordered(),
+                disable=(self.rank != 0),
+            ),
+            self.batch_size,
         ):
+            contexts, all_gen_kwargs = zip(*chunk)
+            gen_kwargs = all_gen_kwargs[
+                0
+            ]  # TODO: handle case where not all gen kwargs are same
             until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
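
The hunk above reorders requests longest-first (hence the new `-len(toks)` sort key) and then walks them in `batch_size` chunks, so each batch groups similar-length prompts and padding waste stays small; `zip(*chunk)` splits a chunk back into contexts and per-request generation kwargs. A toy sketch of that flow (the `chunks` helper below is a stand-in for `utils.chunks`, and the sorted list stands in for `Reorderer`; all sample data is made up):

```python
def chunks(iterable, n):
    # yield successive lists of up to n items, like utils.chunks is used here
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == n:
            yield batch
            batch = []
    if batch:
        yield batch

requests = [
    ("a much longer context string", {"until": ["\n"]}),
    ("short", {"until": ["\n"]}),
    ("mid length", {"until": ["\n"]}),
]
ordered = sorted(requests, key=lambda x: -len(x[0]))  # longest-first, as in _collate
for chunk in chunks(ordered, 2):
    contexts, all_gen_kwargs = zip(*chunk)
    gen_kwargs = all_gen_kwargs[0]  # mirrors the commit's TODO: assumes identical kwargs per batch
    print(contexts, gen_kwargs)
```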
@@ -534,32 +571,42 @@ class HFLM(LM):
                 # max len for inputs = encoder's whole max_length
                 max_ctx_len = self.max_length

-            context_enc = torch.tensor(
-                [self.tok_encode(context, left_truncate_len=max_ctx_len)],
-                device=self.device,
+            context_enc, attn_masks = self.tok_batch_encode(
+                contexts, left_truncate_len=max_ctx_len
             )
+            context_enc = context_enc.to(self.device)
+            attn_masks = attn_masks.to(self.device)
+            # [self.tok_encode(context, left_truncate_len=max_ctx_len)],
+            # device=self.device,
+            # ) for context in contexts]
+            # padding_len = max([context.shape[1] for context in context_enc])
+            # self.tokenizer.batch_encod
+            # context_enc = utils.pad_and_concat(padding_len, context_enc, padding_side="left")

             cont = self._model_generate(
                 context=context_enc,
+                attention_mask=attn_masks,
                 max_length=context_enc.shape[1] + max_gen_toks,
                 stop=primary_until,
                 **kwargs,
             )

-            cont_toks_list = cont[0].tolist()
-
-            # discard context toks if using causal decoder-only LM
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                cont_toks_list = cont_toks_list[context_enc.shape[1] :]
+            cont_toks_list = cont.tolist()

+            for cont_toks, context in zip(cont_toks_list, contexts):
+                # discard context + left-padding toks if using causal decoder-only LM
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    cont_toks = cont_toks[context_enc.shape[1] :]

-            s = self.tok_decode(cont_toks_list)
+                s = self.tok_decode(cont_toks)

-            # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-            for term in until:
-                if len(term) > 0:  # ignore '' separator, for seq2seq case where
-                    s = s.split(term)[0]
+                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+                for term in until:
+                    if len(term) > 0:  # ignore '' separator, for seq2seq case where
+                        s = s.split(term)[0]

-            res.append(s)
+                res.append(s)

-            self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
+                self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)

         return re_ord.get_original(res)
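
On the per-row slice in the last hunk: because prompts are left-padded to a common width, every row of the generated batch begins with exactly `context_enc.shape[1]` padding-plus-prompt tokens, so a single slice strips both at once before decoding and the post-hoc stop-sequence trimming. A toy-tensor sketch of that invariant (no model involved; all values are made up):

```python
import torch

# Two left-padded prompts of width 4 (0 = pad id), plus 3 generated tokens each.
context_enc = torch.tensor([[0, 0, 5, 6],
                            [7, 8, 9, 10]])
generated = torch.tensor([[0, 0, 5, 6, 11, 12, 13],
                          [7, 8, 9, 10, 14, 15, 16]])

for cont_toks in generated.tolist():
    # Every row shares the same prompt width, so one slice drops both the
    # left padding and the real context, leaving only the continuation.
    print(cont_toks[context_enc.shape[1]:])  # [11, 12, 13] then [14, 15, 16]
```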