"docs/en/user_guides/2_new_data_model.md" did not exist on "7dfaf22b9f445ee65dbbebb4ad93911f7873eb8c"
Commit 8352e671 authored by Leo Gao
Browse files

Refactor gpt2 loglikelihood

parent eec18018
......@@ -35,6 +35,21 @@ class GPT2LM(LM):
return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))
def loglikelihood(self, requests):
    """Tokenize (context, continuation) string pairs and score them.

    Each request is a ``(context, continuation)`` pair. Both parts are
    encoded with ``self.tokenizer``; an empty context is replaced by the
    single token id 50256 (end of text as context). The encoded pairs are
    passed to ``self._loglikelihood_tokens`` and its result is returned
    unchanged.
    """

    def _encode_pair(context, continuation):
        # An empty context cannot be encoded to anything useful, so the
        # end-of-text token id stands in for it.
        if context == "":
            context_enc = [50256]
        else:
            context_enc = self.tokenizer.encode(context)
        return (context_enc, self.tokenizer.encode(continuation))

    encoded_requests = [
        _encode_pair(context, continuation)
        for context, continuation in requests
    ]
    return self._loglikelihood_tokens(encoded_requests)
def _loglikelihood_tokens(self, requests):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
with torch.no_grad():
......@@ -42,21 +57,12 @@ class GPT2LM(LM):
# TODO: automatic batch size detection for vectorization
def _collate(x):
toks = self.tokenizer.encode(x[0] + x[1])
return (len(toks), x)
toks = x[0] + x[1]
return (len(toks), tuple(toks))
reord = utils.Reorderer(requests, _collate)
for context, continuation in tqdm(reord.get_reordered()):
for context_enc, continuation_enc in tqdm(reord.get_reordered()):
# when too long to fit in context, truncate from the left
combined_toks = self.tokenizer.encode(context + continuation)
if context == "":
# end of text as context
context_enc = [50256]
else:
context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation)
inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
......@@ -73,7 +79,8 @@ class GPT2LM(LM):
answer = (float(logits.sum()), bool(max_equal))
# partial caching
self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
# TODO: make sure that decode reverses correctly
self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)
res.append(answer)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment