Commit ab1fdc54 authored by Leo Gao

Refactor gpt3 loglikelihood

parent 8352e671
@@ -68,6 +68,21 @@ class GPT3LM(LM):
         return cls(engine=args.get("engine", "davinci"))
 
+    def loglikelihood(self, requests):
+        new_reqs = []
+        for context, continuation in requests:
+            if context == "":
+                # end of text as context
+                context_enc = [50256]
+            else:
+                context_enc = self.tokenizer.encode(context)
+
+            continuation_enc = self.tokenizer.encode(continuation)
+
+            new_reqs.append((context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
+    def _loglikelihood_tokens(self, requests):
         import openai
         res = []
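
This hunk splits string-level tokenization out of the API loop: loglikelihood now encodes each (context, continuation) pair exactly once and hands token ids to _loglikelihood_tokens. Below is a minimal sketch of that step, assuming transformers' GPT-2 tokenizer stands in for self.tokenizer; the request strings are hypothetical:

    import transformers

    tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")

    requests = [("", "Hello world"), ("Q: 2+2=\nA:", " 4")]  # hypothetical inputs
    new_reqs = []
    for context, continuation in requests:
        if context == "":
            # An empty context leaves nothing to condition on, so the
            # <|endoftext|> token (id 50256) serves as a one-token prefix.
            context_enc = [50256]
        else:
            context_enc = tokenizer.encode(context)
        continuation_enc = tokenizer.encode(continuation)
        new_reqs.append((context_enc, continuation_enc))
    # new_reqs now holds token ids only; the scoring loop never re-tokenizes.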
@@ -83,14 +98,7 @@ class GPT3LM(LM):
         for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
             inps = []
             ctxlens = []
-            for context, continuation in chunk:
-                if context == "":
-                    # end of text as context
-                    context_enc = [50256]
-                else:
-                    context_enc = self.tokenizer.encode(context)
-
-                continuation_enc = self.tokenizer.encode(continuation)
-
+            for context_enc, continuation_enc in chunk:
                 inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
                 ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
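
With requests pre-tokenized, the inner loop reduces to the truncation arithmetic: keep the last MAX_LENGTH tokens and shrink ctxlen by however many context tokens were cut, so the continuation is scored in full whenever it fits within MAX_LENGTH. A worked example with hypothetical lengths:

    MAX_LENGTH = 2048  # context window assumed by the harness

    context_enc = [0] * 2000       # 2000 context tokens (dummy ids)
    continuation_enc = [0] * 100   # 100 continuation tokens

    inp = (context_enc + continuation_enc)[-MAX_LENGTH:]
    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - MAX_LENGTH)

    assert len(inp) == 2048   # 2100 tokens total, oldest 52 context tokens dropped
    assert ctxlen == 1948     # 2000 - 52: only the surviving context counts
    assert len(inp) - ctxlen == len(continuation_enc)  # continuation kept whole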
@@ -105,13 +113,14 @@ class GPT3LM(LM):
                 logprobs=10,
             )
 
-            for resp, ctxlen, (context, continuation) in zip(response.choices, ctxlens, chunk):
+            for resp, ctxlen, (context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                 answer = get_result(resp, ctxlen)
 
                 res.append(answer)
 
                 # partial caching
-                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
+                # TODO: make sure that decode reverses correctly
+                self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)
 
         return reord.get_original(res)
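
Because the chunk now carries token ids rather than strings, the partial-cache key is rebuilt by decoding, and the TODO flags that decode(encode(s)) == s is not guaranteed for arbitrary input. A quick check of that round trip, again assuming transformers' GPT-2 tokenizer:

    import transformers

    tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")

    s = "The quick brown fox"
    assert tokenizer.decode(tokenizer.encode(s)) == s  # holds for plain ASCII text

    # The round trip can still drift in edge cases, e.g. a multi-byte character
    # whose bytes are split across the context/continuation boundary will not
    # decode back to the original substrings, so the rebuilt cache keys may not
    # match the raw request strings exactly.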