Commit 52c1c56a authored by Leo Gao's avatar Leo Gao
Browse files

Implement gpt3 logprobs

parent b57d059a
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import os import os
import transformers import transformers
from lm_eval.base import LM from lm_eval.base import LM
from lm_eval import utils from lm_eval import utils
from tqdm import tqdm
def get_result(response, ctxlen):
    """Process one OpenAI completion choice into loglikelihood results.

    :param response: a completion choice (dict-like) whose "logprobs" field
        contains parallel lists "tokens", "token_logprobs" and "top_logprobs".
    :param ctxlen: number of leading tokens that belong to the context;
        everything from index ctxlen onward is the continuation being scored.
    :return: tuple (continuation_logprobs, is_greedy) where
        continuation_logprobs is the summed log-probability of the
        continuation tokens, and is_greedy is True iff every continuation
        token was also the model's top-1 token at that position.
    """
    logprobs = response["logprobs"]["token_logprobs"]
    continuation_logprobs = sum(logprobs[ctxlen:])

    is_greedy = True
    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
        token = response["logprobs"]["tokens"][i]
        top_tokens = response["logprobs"]["top_logprobs"][i]
        # argmax over the returned top-k alternatives at this position
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy
class GPT3LM(LM): class GPT3LM(LM):
MAX_LENGTH = 2048 MAX_LENGTH = 2048
REQ_CHUNK_SIZE = 64
def __init__(self, engine, truncate=False): def __init__(self, engine, truncate=False):
""" """
...@@ -31,23 +48,36 @@ class GPT3LM(LM): ...@@ -31,23 +48,36 @@ class GPT3LM(LM):
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
return cls(engine=args.get("engine", "davinci")) return cls(engine=args.get("engine", "davinci"))
def loglikelihood(self, requests):
    """Compute loglikelihoods of (context, continuation) pairs via the OpenAI API.

    Requests are processed in chunks of REQ_CHUNK_SIZE so one API call scores
    a whole batch of prompts. Using echo=True with max_tokens=0 makes the API
    return logprobs for the prompt tokens themselves without generating.

    :param requests: iterable of (context, continuation) string pairs.
    :return: list of (continuation_logprob_sum, is_greedy) tuples in the same
        order as the input requests.
    """
    # Imported lazily so the openai package/key is only required when used.
    import openai

    res = []
    for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)):
        inps = []
        ctxlens = []
        for context, continuation in chunk:
            context_enc = self.tokenizer.encode(context)
            continuation_enc = self.tokenizer.encode(continuation)
            # Keep at most MAX_LENGTH tokens, truncating from the left.
            inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
            # Number of surviving tokens that are still context tokens;
            # if truncation ate into the context, shrink ctxlen accordingly.
            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH
            )

            inps.append(inp)
            ctxlens.append(ctxlen)

        response = openai.Completion.create(
            engine=self.engine,
            prompt=inps,
            echo=True,
            max_tokens=0, temperature=0.,
            logprobs=10,
        )

        for resp, ctxlen in zip(response.choices, ctxlens):
            res.append(get_result(resp, ctxlen))
    return res
def greedy_until(self, requests):
    """Greedy generation up to stopping sequences — not yet implemented.

    :param requests: iterable of generation requests (currently ignored).
    :return: None; this stub performs no work.
    """
    # TODO: implement
    pass
\ No newline at end of file
...@@ -29,3 +29,14 @@ def simple_parse_args_string(args_string): ...@@ -29,3 +29,14 @@ def simple_parse_args_string(args_string):
def join_iters(iters):
    """Chain an iterable of iterables into one flat generator.

    :param iters: iterable whose elements are themselves iterable.
    :yield: each item of each inner iterable, in order.
    """
    # Loop variable renamed from `iter` to avoid shadowing the builtin.
    for inner in iters:
        yield from inner
def chunks(iter, n):
    """Yield successive lists of at most n items taken from iter.

    Every yielded list has exactly n items except possibly the last,
    which holds whatever remains.

    :param iter: source iterable to batch.
    :param n: maximum batch size (positive int).
    :yield: lists of consecutive items.
    """
    batch = []
    for item in iter:
        batch.append(item)
        if len(batch) == n:
            yield batch
            batch = []
    # Flush the final, possibly short, batch.
    if batch:
        yield batch
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment