Commit b250b001 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

clean up batched code and add comments

parent ffc3a456
......@@ -534,26 +534,27 @@ class HFLM(LM):
toks = self.tok_encode(x[0])
return -len(toks), x[0]
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests))
assert len(requests) == sum(
[len(list(re_ord.get_reordered())) for re_ord in re_ords.values()]
)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
for chunk in utils.chunks(
# tqdm(
re_ord.get_reordered(),
# disable=(self.rank != 0),
# ),
self.batch_size,
):
contexts, all_gen_kwargs = zip(*chunk)
gen_kwargs = all_gen_kwargs[
0
] # TODO: handle case where not all gen kwargs are same
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
......@@ -586,13 +587,14 @@ class HFLM(LM):
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
# encode, pad, and truncate contexts
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts, left_truncate_len=max_ctx_len
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
......@@ -611,18 +613,20 @@ class HFLM(LM):
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0: # ignore '' separator, for seq2seq case where
if len(term) > 0:
# ignore '' separator,
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
res[str(gen_kwargs)].append(
s
) # TODO: move this to res[-1].append(s) to separate per re_ord
res[key].append(s)
self.cache_hook.add_partial(
"greedy_until", (context, gen_kwargs), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close()
return grouper.get_original(res)
# return utils.join_iters([re_ord.get_original(rs) for re_ord, rs in zip(re_ords, res.values())])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment