Commit ffc3a456 authored by haileyschoelkopf

push WIP batched multi-kwarg code

parent 83f957bc
@@ -3,6 +3,7 @@ import transformers
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 import copy
+from collections import defaultdict
 from tqdm import tqdm
 import torch.nn.functional as F
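
The new `defaultdict` import lets results accumulate per generation-kwargs key without initializing each bucket by hand. A small illustrative snippet (the kwargs values here are made up, not taken from the diff):

```python
from collections import defaultdict

# Appending to a missing key creates an empty list automatically,
# which is what the grouped `res` container below relies on.
res = defaultdict(list)
res[str({"max_new_tokens": 32})].append("first output")
res[str({"max_new_tokens": 32})].append("second output")
print(dict(res))  # {"{'max_new_tokens': 32}": ['first output', 'second output']}
```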
@@ -520,19 +521,33 @@ class HFLM(LM):
         return re_ord.get_original(res)
 
     def greedy_until(self, requests):
-        res = []
+        res = defaultdict(list)
+        re_ords = {}
 
         def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
+            # - when going through the sorted list, the first element of each batch is always the batch's
+            #   padded context length. this simplifies the batching logic and, more importantly, makes
+            #   automatic adaptive batching much easier to implement
+            # - any OOMs will happen right away rather than near the end
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]
 
-        re_ord = utils.Reorderer([req.args for req in requests], _collate)
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
 
+        pbar = tqdm(total=len(requests))
+        assert len(requests) == sum(
+            [len(list(re_ord.get_reordered())) for re_ord in re_ords.values()]
+        )
-        for chunk in utils.chunks(
-            tqdm(
-                re_ord.get_reordered(),
-                disable=(self.rank != 0),
-            ),
-            self.batch_size,
-        ):
-            contexts, all_gen_kwargs = zip(*chunk)
+        for key, re_ord in re_ords.items():
+            for chunk in utils.chunks(
+                # tqdm(
+                re_ord.get_reordered(),
+                # disable=(self.rank != 0),
+                # ),
+                self.batch_size,
+            ):
+                contexts, all_gen_kwargs = zip(*chunk)
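
The core of this change is to bucket requests by their stringified generation kwargs (`str(x.args[1])`), so every batch passed to the model shares a single kwargs dict, and then to scatter per-group results back into the original request order at the end. A minimal sketch of what the `utils.Grouper` helper used above could look like (hypothetical; the harness's actual implementation may differ):

```python
from collections import defaultdict

class Grouper:
    """Bucket items by a key function, then scatter grouped results
    back to the items' original order. Sketch only."""

    def __init__(self, arr, fn):
        self.size = len(arr)
        buckets = defaultdict(list)
        for i, x in enumerate(arr):
            buckets[fn(x)].append((i, x))
        # remember each item's original index per group so order can be restored
        self._indices = {k: [i for i, _ in v] for k, v in buckets.items()}
        self._grouped = {k: [x for _, x in v] for k, v in buckets.items()}

    def get_grouped(self):
        return self._grouped

    def get_original(self, grouped_results):
        # grouped_results: {key: [result, ...]} in the same per-group order
        res = [None] * self.size
        for k, values in grouped_results.items():
            for i, v in zip(self._indices[k], values):
                res[i] = v
        return res
```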
@@ -599,8 +614,15 @@ class HFLM(LM):
-                if len(term) > 0:  # ignore '' separator, for seq2seq case where
-                    s = s.split(term)[0]
+                    if len(term) > 0:  # ignore '' separator, for seq2seq case where
+                        s = s.split(term)[0]
 
-                res.append(s)
+                    res[str(gen_kwargs)].append(
+                        s
+                    )  # TODO: move this to res[-1].append(s) to separate per re_ord
 
-                self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
-
-        return re_ord.get_original(res)
+                    self.cache_hook.add_partial(
+                        "greedy_until", (context, gen_kwargs), s
+                    )
+                    pbar.update(1)
+            res[key] = re_ord.get_original(res[key])
+        pbar.close()
+        return grouper.get_original(res)
+        # return utils.join_iters([re_ord.get_original(rs) for re_ord, rs in zip(re_ords, res.values())])
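
Within each kwargs group, requests are still sorted longest-first by `utils.Reorderer` using `_collate`, so the padded batch width is known up front and any OOM surfaces immediately. A rough sketch of such a reorderer, shown only to illustrate the sort-then-unsort pattern (again hypothetical; the harness's real `Reorderer` may differ):

```python
class Reorderer:
    """Sort items by a collate key (here: descending token length),
    then undo the permutation on the results. Sketch only."""

    def __init__(self, arr, fn):
        # keep (original_index, item) pairs so the sort is reversible
        self._pairs = sorted(enumerate(arr), key=lambda p: fn(p[1]))

    def get_reordered(self):
        return [x for _, x in self._pairs]

    def get_original(self, results):
        res = [None] * len(self._pairs)
        for (i, _), r in zip(self._pairs, results):
            res[i] = r
        return res
```

With these two helpers, `greedy_until` runs one reordered, chunked generation pass per distinct gen_kwargs, collects outputs under `res[str(gen_kwargs)]`, unsorts each group via `re_ord.get_original`, and finally maps everything back to request order through `grouper.get_original(res)`.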