Commit ffc3a456 authored by haileyschoelkopf

push WIP batched multi-kwarg code

parent 83f957bc
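The change below groups generation requests by their generation kwargs, so that every batch handed to the model shares a single kwargs dict (one `until` list, one `max_gen_toks`, and so on). A rough sketch of that grouping idea, using made-up (context, gen_kwargs) tuples standing in for `req.args`; this is an illustration of the approach, not code from the harness:

from collections import defaultdict

# Hypothetical requests: (context, gen_kwargs) tuples standing in for req.args.
requests = [
    ("Translate: bonjour", {"until": ["\n"], "max_gen_toks": 32}),
    ("Q: 2 + 2 =", {"until": ["\n"]}),
    ("Translate: merci", {"until": ["\n"], "max_gen_toks": 32}),
]

# Key each request by the string form of its kwargs, mirroring
# utils.Grouper(requests, lambda x: str(x.args[1])) in the diff.
groups = defaultdict(list)
for context, gen_kwargs in requests:
    groups[str(gen_kwargs)].append((context, gen_kwargs))

for key, reqs in groups.items():
    # every request in this bucket shares one gen_kwargs dict, so one batched
    # generate() call per chunk can apply the same stop sequences and token limit
    print(f"{len(reqs)} request(s) with kwargs {key}")

With requests bucketed this way, each chunk of a bucket can be generated in one batched call instead of falling back to the first request's kwargs for the whole batch.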
@@ -3,6 +3,7 @@ import transformers
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 import copy
+from collections import defaultdict
 from tqdm import tqdm
 import torch.nn.functional as F
@@ -520,87 +521,108 @@ class HFLM(LM):
         return re_ord.get_original(res)

     def greedy_until(self, requests):
-        res = []
+        res = defaultdict(list)
+        re_ords = {}

         def _collate(x):
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
             #   padded context length. this is useful to simplify the batching logic and more importantly to make
             #   automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]

-        re_ord = utils.Reorderer([req.args for req in requests], _collate)
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

-        for chunk in utils.chunks(
-            tqdm(
+        pbar = tqdm(total=len(requests))
+        assert len(requests) == sum(
+            [len(list(re_ord.get_reordered())) for re_ord in re_ords.values()]
+        )
+        for key, re_ord in re_ords.items():
+            for chunk in utils.chunks(
+                # tqdm(
                 re_ord.get_reordered(),
-                disable=(self.rank != 0),
-            ),
-            self.batch_size,
-        ):
-            contexts, all_gen_kwargs = zip(*chunk)
-            gen_kwargs = all_gen_kwargs[
-                0
-            ]  # TODO: handle case where not all gen kwargs are same
-            until = None
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [kwargs]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
-            else:
-                raise ValueError(
-                    f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
-                )
-            if not until:
-                until = [self.tok_decode(self.eot_token_id)]
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # first stop sequence is used to halt generation upon encountering
-            (primary_until) = until[0]
-            # set the max length in tokens of inputs ("context_enc")
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                # max len for inputs = max length, minus room to generate the max new tokens
-                max_ctx_len = self.max_length - max_gen_toks
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-                # max len for inputs = encoder's whole max_length
-                max_ctx_len = self.max_length
-            # encode, pad, and truncate contexts
-            context_enc, attn_masks = self.tok_batch_encode(
-                contexts, left_truncate_len=max_ctx_len
-            )
-            context_enc = context_enc.to(self.device)
-            attn_masks = attn_masks.to(self.device)
-            cont = self._model_generate(
-                context=context_enc,
-                attention_mask=attn_masks,
-                max_length=context_enc.shape[1] + max_gen_toks,
-                stop=primary_until,
-                **kwargs,
-            )
-            cont_toks_list = cont.tolist()
-            for cont_toks, context in zip(cont_toks_list, contexts):
-                # discard context + left-padding toks if using causal decoder-only LM
-                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                    cont_toks = cont_toks[context_enc.shape[1] :]
-                s = self.tok_decode(cont_toks)
-                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-                for term in until:
-                    if len(term) > 0:  # ignore '' separator, for seq2seq case where
-                        s = s.split(term)[0]
-                res.append(s)
-                self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
-        return re_ord.get_original(res)
+                # disable=(self.rank != 0),
+                # ),
+                self.batch_size,
+            ):
+                contexts, all_gen_kwargs = zip(*chunk)
+                gen_kwargs = all_gen_kwargs[
+                    0
+                ]  # TODO: handle case where not all gen kwargs are same
+                until = None
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    if "until" in kwargs.keys():
+                        until = kwargs.pop("until")
+                        if isinstance(until, str):
+                            until = [kwargs]
+                        elif not isinstance(until, list):
+                            raise ValueError(
+                                f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
+                    )
+                if not until:
+                    until = [self.tok_decode(self.eot_token_id)]
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # first stop sequence is used to halt generation upon encountering
+                (primary_until) = until[0]
+                # set the max length in tokens of inputs ("context_enc")
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    # max len for inputs = max length, minus room to generate the max new tokens
+                    max_ctx_len = self.max_length - max_gen_toks
+                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                    # max len for inputs = encoder's whole max_length
+                    max_ctx_len = self.max_length
+                # encode, pad, and truncate contexts
+                context_enc, attn_masks = self.tok_batch_encode(
+                    contexts, left_truncate_len=max_ctx_len
+                )
+                context_enc = context_enc.to(self.device)
+                attn_masks = attn_masks.to(self.device)
+                cont = self._model_generate(
+                    context=context_enc,
+                    attention_mask=attn_masks,
+                    max_length=context_enc.shape[1] + max_gen_toks,
+                    stop=primary_until,
+                    **kwargs,
+                )
+                cont_toks_list = cont.tolist()
+                for cont_toks, context in zip(cont_toks_list, contexts):
+                    # discard context + left-padding toks if using causal decoder-only LM
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                        cont_toks = cont_toks[context_enc.shape[1] :]
+                    s = self.tok_decode(cont_toks)
+                    # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+                    for term in until:
+                        if len(term) > 0:  # ignore '' separator, for seq2seq case where
+                            s = s.split(term)[0]
+                    res[str(gen_kwargs)].append(
+                        s
+                    )  # TODO: move this to res[-1].append(s) to separate per re_ord
+                    self.cache_hook.add_partial(
+                        "greedy_until", (context, gen_kwargs), s
+                    )
+                    pbar.update(1)
+            res[key] = re_ord.get_original(res[key])
+        pbar.close()
+        return grouper.get_original(res)
+        # return utils.join_iters([re_ord.get_original(rs) for re_ord, rs in zip(re_ords, res.values())])
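The new path depends on a `utils.Grouper` helper that is not part of this diff. A minimal sketch of the interface the method appears to assume — `get_grouped()` returning a key-to-items dict and `get_original()` restoring the original request order from per-key result lists; the class body here is an assumption, not the harness's actual implementation:

from collections import defaultdict


class Grouper:
    """Assumed interface: bucket items by a key function, then restore order."""

    def __init__(self, items, key_fn):
        self._grouped = defaultdict(list)
        self._order = []  # (key, position within that key's bucket) for each item
        for item in items:
            key = key_fn(item)
            self._order.append((key, len(self._grouped[key])))
            self._grouped[key].append(item)

    def get_grouped(self):
        # key -> list of items that share that key
        return self._grouped

    def get_original(self, grouped_results):
        # grouped_results: key -> list of results, aligned with get_grouped()[key]
        return [grouped_results[key][pos] for key, pos in self._order]


# Example round trip with hypothetical (context, gen_kwargs) requests:
# reqs = [("a", {"until": ["\n"]}), ("b", {}), ("c", {"until": ["\n"]})]
# grouper = Grouper(reqs, lambda r: str(r[1]))
# results = {k: [ctx.upper() for ctx, _ in v] for k, v in grouper.get_grouped().items()}
# grouper.get_original(results)  # -> ["A", "B", "C"]

Under that assumption, `grouper.get_original(res)` at the end of `greedy_until` interleaves the per-kwargs result lists back into the order of the incoming requests.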