Commit 52c1c56a authored by Leo Gao

Implement gpt3 logprobs

parent b57d059a
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import os
import transformers
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm

def get_result(response, ctxlen):
    """Return (continuation log-likelihood, whether the continuation is the greedy completion)."""
    is_greedy = True
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        print('TOK', token, response["logprobs"]["top_logprobs"][i])
        top_tokens = response["logprobs"]["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
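
A minimal sketch of what get_result expects, using a hand-built choice dict shaped like the Completion API's logprobs payload (the values below are made up for illustration, not real API output):

mock_choice = {
    "logprobs": {
        "tokens": ["Hello", ",", " world"],
        "token_logprobs": [None, -1.2, -0.5],
        "top_logprobs": [None, {",": -1.2}, {" world": -0.5, "!": -2.1}],
    },
}

# ctxlen=2 skips the two context tokens, so only " world" is scored.
logprob, is_greedy = get_result(mock_choice, ctxlen=2)
# logprob == -0.5; is_greedy is True because " world" also has the highest
# value in its top_logprobs dict.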

class GPT3LM(LM):
    MAX_LENGTH = 2048
    REQ_CHUNK_SIZE = 64

    def __init__(self, engine, truncate=False):
        """
@@ -31,23 +48,36 @@ class GPT3LM(LM):
        args = utils.simple_parse_args_string(arg_string)
        return cls(engine=args.get("engine", "davinci"))
-    def loglikelihood(self, context, continuation):
+    # TODO: implement new framework
+    def loglikelihood(self, requests):
        import openai

-        context_enc = self.tokenizer.encode(context)
-        continuation_enc = self.tokenizer.encode(continuation)
-        inp = (context_enc + continuation_enc)[-1024:]
-        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
-        response = openai.Completion.create(
-            engine=self.engine,
-            prompt=inp,
-            echo=True,
-            max_tokens=0, temperature=0.0,
-            logprobs=0,
-        )
-        logprobs = response.choices[0]["logprobs"]["token_logprobs"]
-        continuation_logprobs = logprobs[ctxlen:]
-        return sum(continuation_logprobs)
+        res = []
+        for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)):
+            inps = []
+            ctxlens = []
+            for context, continuation in chunk:
+                print(context)
+                context_enc = self.tokenizer.encode(context)
+                continuation_enc = self.tokenizer.encode(continuation)
+                inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
+                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
+                inps.append(inp)
+                ctxlens.append(ctxlen)
+
+            response = openai.Completion.create(
+                engine=self.engine,
+                prompt=inps,
+                echo=True,
+                max_tokens=0, temperature=0.,
+                logprobs=10,
+            )
+
+            for resp, ctxlen in zip(response.choices, ctxlens):
+                res.append(get_result(resp, ctxlen))
+
+        return res
    def greedy_until(self, requests):
        # TODO: implement
        pass
\ No newline at end of file
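
The truncation arithmetic in the new loglikelihood keeps the last MAX_LENGTH tokens and charges any overflow to the context: with a 2,000-token context and a 100-token continuation, inp keeps the final 2048 tokens and ctxlen = 2000 - max(0, 2100 - 2048) = 1948, so the 52 dropped tokens all come off the left edge of the context and the continuation is always scored in full. A hypothetical end-to-end call, assuming openai.api_key is configured and the instance's tokenizer is set up by __init__ (the request pairs below are illustrative):

lm = GPT3LM(engine="davinci")
requests = [
    ("The capital of France is", " Paris"),
    ("2 + 2 =", " 4"),
]
# Each result is get_result's (continuation_logprob_sum, is_greedy) tuple,
# returned in the same order as the input requests.
for (ctx, cont), (logprob, greedy) in zip(requests, lm.loglikelihood(requests)):
    print(repr(cont), logprob, greedy)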
@@ -29,3 +29,14 @@ def simple_parse_args_string(args_string):
def join_iters(iters):
    for iter in iters:
        yield from iter

+def chunks(iter, n):
+    arr = []
+    for x in iter:
+        arr.append(x)
+        if len(arr) == n:
+            yield arr
+            arr = []
+
+    if arr: yield arr
\ No newline at end of file
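
A quick sanity check of the new chunks helper (values are illustrative); note that the trailing if arr flushes the final short chunk:

from lm_eval.utils import chunks

print(list(chunks(range(5), 2)))  # [[0, 1], [2, 3], [4]]
print(list(chunks([], 3)))        # []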