Commit f3fee648 authored by Leo Gao

Roll back last token optimization

parent 59a0104d
@@ -41,12 +41,13 @@ class GPT2LM(LM):
         # TODO: automatic batch size detection for vectorization
         def _collate(x):
-            toks = self.tokenizer.encode(x[0] + x[1])[:-1]
-            return (len(toks), self.tokenizer.decode(toks))
+            toks = self.tokenizer.encode(x[0] + x[1])
+            return (len(toks), x)
 
         reord = utils.Reorderer(requests, _collate)
 
         for context, continuation in tqdm(reord.get_reordered()):
             # when too long to fit in context, truncate from the left
+            combined_toks = self.tokenizer.encode(context + continuation)
 
             if context == "":
                 # end of text as context
@@ -68,12 +69,9 @@ class GPT2LM(LM):
             logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
 
-            res.append((float(logits[:, :-1].sum() if logits.shape[-1] > 1 else 0), last_token_slice, bool(max_equal)))
+            res.append((float(logits.sum()), bool(max_equal)))
 
-        # optimization: if two requests have everything the same except the last token, use
-        # last token distribution to save compute
-        lasttoks = [self.tokenizer.encode(x[1])[-1] for x in requests]
-        return [(l + lts[lasttok], m) for (l, lts, m), lasttok in zip(reord.get_original(res), lasttoks)]
+        return reord.get_original(res)
 
     def greedy_until(self, requests):
         # TODO: implement fully general `until` that handles untils that are
...
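For context, the optimization being rolled back scored requests that differ only in their final token with a single forward pass: the logits at the last prefix position already contain the distribution over every candidate last token, so the shared prefix only needs to run through the model once. Below is a minimal sketch of that idea, not this repository's code; it assumes the Hugging Face transformers GPT-2 API, and the helper name last_token_logprobs is hypothetical.

    # Sketch of the rolled-back idea (hypothetical helper, not repo code):
    # one forward pass over a shared prefix scores every candidate last token.
    import torch
    import torch.nn.functional as F
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

    def last_token_logprobs(prefix, candidates):
        # Run the model once over the prefix shared by all candidates.
        prefix_toks = tokenizer.encode(prefix)
        with torch.no_grad():
            logits = model(torch.tensor([prefix_toks])).logits  # [1, seq, vocab]
        # The final position's log-softmax is the distribution over the
        # next (i.e. last) token, so each candidate is just a lookup.
        dist = F.log_softmax(logits[0, -1], dim=-1)
        return {c: dist[tokenizer.encode(c)[-1]].item() for c in candidates}

    # " yes" and " no" share the entire prompt as a prefix, so one forward
    # pass scores both continuations.
    print(last_token_logprobs("Is the sky blue? Answer:", [" yes", " no"]))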