Commit 8352e671 authored by Leo Gao

Refactor gpt2 loglikelihood

parent eec18018
@@ -35,6 +35,21 @@ class GPT2LM(LM):
         return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))
 
     def loglikelihood(self, requests):
+        new_reqs = []
+        for context, continuation in requests:
+            if context == "":
+                # end of text as context
+                context_enc = [50256]
+            else:
+                context_enc = self.tokenizer.encode(context)
+
+            continuation_enc = self.tokenizer.encode(continuation)
+
+            new_reqs.append((context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
+    def _loglikelihood_tokens(self, requests):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
         with torch.no_grad():
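For illustration (not part of the commit): the new loglikelihood() tokenizes each request once up front, substituting the GPT-2 end-of-text token (id 50256) when the context is empty, and passes token ids rather than strings to _loglikelihood_tokens(). A minimal standalone sketch, assuming the Hugging Face GPT2Tokenizer that this class wraps:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

requests = [("", " The cat sat"), ("The cat", " sat")]  # made-up example requests
new_reqs = []
for context, continuation in requests:
    if context == "":
        # empty context falls back to the end-of-text token, id 50256 for GPT-2
        context_enc = [50256]
    else:
        context_enc = tokenizer.encode(context)
    continuation_enc = tokenizer.encode(continuation)
    new_reqs.append((context_enc, continuation_enc))

print(new_reqs[0])  # ([50256], [...]) -- ids, not strings, flow downstream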
@@ -42,21 +57,12 @@ class GPT2LM(LM):
             # TODO: automatic batch size detection for vectorization
 
             def _collate(x):
-                toks = self.tokenizer.encode(x[0] + x[1])
-                return (len(toks), x)
+                toks = x[0] + x[1]
+                return (len(toks), tuple(toks))
 
             reord = utils.Reorderer(requests, _collate)
-            for context, continuation in tqdm(reord.get_reordered()):
+            for context_enc, continuation_enc in tqdm(reord.get_reordered()):
                 # when too long to fit in context, truncate from the left
-                combined_toks = self.tokenizer.encode(context + continuation)
-
-                if context == "":
-                    # end of text as context
-                    context_enc = [50256]
-                else:
-                    context_enc = self.tokenizer.encode(context)
-
-                continuation_enc = self.tokenizer.encode(continuation)
 
                 inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
                 ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
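A worked example (toy numbers, not from the commit) of the left-truncation arithmetic above. Since requests are now reordered by total token count via _collate, the loop already sees token ids; when context plus continuation exceed max_length, tokens are dropped from the left of the context, and ctxlen records how many context tokens survive so downstream code knows where the continuation's logits begin:

max_length = 5                 # toy value; the real model uses e.g. 1024
context_enc = [1, 2, 3, 4]     # 4 context token ids
continuation_enc = [5, 6, 7]   # 3 continuation token ids

inp = (context_enc + continuation_enc)[-max_length:]
# inp == [3, 4, 5, 6, 7]: the two oldest context tokens were truncated away

ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - max_length)
# ctxlen == 4 - max(0, 4 + 3 - 5) == 2: two context tokens remain, so the
# continuation starts at position 2 of the truncated input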
@@ -73,7 +79,8 @@ class GPT2LM(LM):
                 answer = (float(logits.sum()), bool(max_equal))
 
                 # partial caching
-                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
+                # TODO: make sure that decode reverses correctly
+                self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)
                 res.append(answer)
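The TODO above flags a real subtlety: the cache key is now built from decoded token ids rather than the original strings, which matches only if decode() exactly inverts encode(). A hedged sketch of the check one might run (GPT-2's byte-level BPE usually round-trips, but decode options such as clean_up_tokenization_spaces can alter whitespace around punctuation):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

context, continuation = "The cat", " sat on the mat ."  # made-up strings
key = (tokenizer.decode(tokenizer.encode(context)),
       tokenizer.decode(tokenizer.encode(continuation)))
print(key == (context, continuation))  # may be False if decode cleans up " ."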
...