Commit f3fee648 authored by Leo Gao
Browse files

Roll back last token optimization

parent 59a0104d
......@@ -41,26 +41,27 @@ class GPT2LM(LM):
# TODO: automatic batch size detection for vectorization
def _collate(x):
toks = self.tokenizer.encode(x[0] + x[1])[:-1]
return (len(toks), self.tokenizer.decode(toks))
toks = self.tokenizer.encode(x[0] + x[1])
return (len(toks), x)
reord = utils.Reorderer(requests, _collate)
for context, continuation in tqdm(reord.get_reordered()):
# when too long to fit in context, truncate from the left
combined_toks = self.tokenizer.encode(context + continuation)
if context == "":
# end of text as context
context_enc = [50256]
else:
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
cont_toks = inp[:, ctxlen:] # [batch, seq]
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all()
......@@ -68,12 +69,9 @@ class GPT2LM(LM):
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq]
res.append((float(logits[:, :-1].sum() if logits.shape[-1] > 1 else 0), last_token_slice, bool(max_equal)))
res.append((float(logits.sum()), bool(max_equal)))
# optimization: if two requests have everything the same except the last token, use
# last token distribution to save compute
lasttoks = [self.tokenizer.encode(x[1])[-1] for x in requests]
return [(l + lts[lasttok], m) for (l, lts, m), lasttok in zip(reord.get_original(res), lasttoks)]
return reord.get_original(res)
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment