Commit 5f4c7c50 authored by Leo Gao's avatar Leo Gao
Browse files

Implement gpt3 greedy_until

parent 38e8858f
......@@ -25,6 +25,7 @@ class GPT3LM(LM):
MAX_LENGTH = 2048
REQ_CHUNK_SIZE = 64
MAX_GEN_TOKS = 256
def __init__(self, engine, truncate=False):
"""
......@@ -48,9 +49,6 @@ class GPT3LM(LM):
return cls(engine=args.get("engine", "davinci"))
def loglikelihood(self, requests):
import openai
res = []
for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)):
inps = []
ctxlens = []
......@@ -78,5 +76,23 @@ class GPT3LM(LM):
return res
def greedy_until(self, requests):
# TODO: implement
pass
\ No newline at end of file
import openai
res = []
for context, until in tqdm(requests):
context_enc = self.tokenizer.encode(context)
inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS))
response = openai.Completion.create(
engine=self.engine,
prompt=[inp],
max_tokens=self.MAX_GEN_TOKS,
temperature=0.,
logprobs=10,
)
res.append(response.choices[0]['text'])
return res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment