Commit ffc3a456 authored by haileyschoelkopf

push WIP batched multi-kwarg code

parent 83f957bc
@@ -3,6 +3,7 @@ import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
import copy
from collections import defaultdict
from tqdm import tqdm
import torch.nn.functional as F
@@ -520,87 +521,108 @@ class HFLM(LM):
        return re_ord.get_original(res)

    def greedy_until(self, requests):
        res = defaultdict(list)
        re_ords = {}

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over- not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # group requests by their stringified generation kwargs, so each batch uses one consistent config
        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each group, reorder requests by descending tokenized context length
            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
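        # Illustration (hypothetical values): if one task passes gen_kwargs
        # {"until": ["\n\n"], "do_sample": False} and another passes
        # {"until": ["</s>"], "do_sample": False}, str(x.args[1]) yields two distinct keys,
        # so each bucket (and therefore each batch built below) shares a single generation config.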
        pbar = tqdm(total=len(requests))
        assert len(requests) == sum(
            [len(list(re_ord.get_reordered())) for re_ord in re_ords.values()]
        )

        for key, re_ord in re_ords.items():
            for chunk in utils.chunks(
                # tqdm(
                re_ord.get_reordered(),
                # disable=(self.rank != 0),
                # ),
                self.batch_size,
            ):
                contexts, all_gen_kwargs = zip(*chunk)
                gen_kwargs = all_gen_kwargs[
                    0
                ]  # TODO: handle case where not all gen kwargs are same
                until = None
                if isinstance(gen_kwargs, dict):
                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                    if "until" in kwargs.keys():
                        until = kwargs.pop("until")
                        if isinstance(until, str):
                            until = [until]
                        elif not isinstance(until, list):
                            raise ValueError(
                                f"Expected `generation_kwargs['until']` to be of type Union[str, list] but got {until}"
                            )
                else:
                    raise ValueError(
                        f"Expected `generation_kwargs` to be of type `dict` but got {gen_kwargs}"
                    )
                if not until:
                    until = [self.tok_decode(self.eot_token_id)]
                if "max_gen_toks" in kwargs.keys():
                    max_gen_toks = kwargs.pop("max_gen_toks")
                else:
                    max_gen_toks = self.max_gen_toks
                # the first stop sequence is used to halt generation upon encountering it
                primary_until = until[0]
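                # Illustration (hypothetical values): for gen_kwargs = {"until": ["\n\n", "Question:"], "max_gen_toks": 128},
                # primary_until == "\n\n" is handed to _model_generate as its stop sequence, while the remaining
                # entries of `until` are only applied post-hoc when trimming the decoded string further below.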
                # set the max length in tokens of inputs ("context_enc")
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    # max len for inputs = max length, minus room to generate the max new tokens
                    max_ctx_len = self.max_length - max_gen_toks
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    # max len for inputs = encoder's whole max_length
                    max_ctx_len = self.max_length

                # encode, pad, and truncate contexts for this batch
                context_enc, attn_masks = self.tok_batch_encode(
                    contexts, left_truncate_len=max_ctx_len
                )
                context_enc = context_enc.to(self.device)
                attn_masks = attn_masks.to(self.device)

                cont = self._model_generate(
                    context=context_enc,
                    attention_mask=attn_masks,
                    max_length=context_enc.shape[1] + max_gen_toks,
                    stop=primary_until,
                    **kwargs,
                )
                cont_toks_list = cont.tolist()
                for cont_toks, context in zip(cont_toks_list, contexts):
                    # discard context + left-padding toks if using causal decoder-only LM
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                        cont_toks = cont_toks[context_enc.shape[1] :]

                    s = self.tok_decode(cont_toks)

                    # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                    for term in until:
                        if len(term) > 0:
                            # ignore '' separator, for the seq2seq case where the eot token decodes to ''
                            s = s.split(term)[0]

                    res[str(gen_kwargs)].append(
                        s
                    )  # TODO: move this to res[-1].append(s) to separate per re_ord
                    self.cache_hook.add_partial(
                        "greedy_until", (context, gen_kwargs), s
                    )
                    pbar.update(1)

            # re-sort this group's results back into its original within-group order
            res[key] = re_ord.get_original(res[key])

        pbar.close()

        # restore the original ordering across all groups
        return grouper.get_original(res)
        # return utils.join_iters([re_ord.get_original(rs) for re_ord, rs in zip(re_ords, res.values())])
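For context on the helper this diff leans on: utils.Grouper is only exercised here through construction with a key function, get_grouped(), and get_original(). The snippet below is a minimal sketch of a Grouper-like class consistent with just that interface, using hypothetical names (MinimalGrouper) and toy requests; the real implementation in the harness's utils module may differ.

from collections import defaultdict


class MinimalGrouper:
    """Sketch only: bucket items by key_fn and scatter per-bucket results back to the original order."""

    def __init__(self, items, key_fn):
        self._size = len(items)
        self._groups = defaultdict(list)   # key -> [item, ...]
        self._indices = defaultdict(list)  # key -> [original position, ...]
        for i, item in enumerate(items):
            key = key_fn(item)
            self._groups[key].append(item)
            self._indices[key].append(i)

    def get_grouped(self):
        # key -> items, preserving each bucket's encounter order
        return self._groups

    def get_original(self, grouped_results):
        # grouped_results: key -> results aligned with self._groups[key]
        out = [None] * self._size
        for key, results in grouped_results.items():
            for i, r in zip(self._indices[key], results):
                out[i] = r
        return out


# usage mirroring greedy_until: group (context, gen_kwargs) pairs by their stringified kwargs
requests = [("ctx0", {"until": ["\n"]}), ("ctx1", {"until": ["."]}), ("ctx2", {"until": ["\n"]})]
grouper = MinimalGrouper(requests, lambda x: str(x[1]))
results = {key: [f"gen-for-{ctx}" for ctx, _ in reqs] for key, reqs in grouper.get_grouped().items()}
assert grouper.get_original(results) == ["gen-for-ctx0", "gen-for-ctx1", "gen-for-ctx2"]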