"vscode:/vscode.git/clone" did not exist on "5da7cd6950b0e9a5db528d62a21b4dd9466ba9d1"
Commit 5f4c7c50 authored by Leo Gao

Implement gpt3 greedy_until

parent 38e8858f
@@ -25,6 +25,7 @@ class GPT3LM(LM):
 
     MAX_LENGTH = 2048
     REQ_CHUNK_SIZE = 64
+    MAX_GEN_TOKS = 256
 
     def __init__(self, engine, truncate=False):
         """
@@ -48,9 +49,6 @@ class GPT3LM(LM):
         return cls(engine=args.get("engine", "davinci"))
 
     def loglikelihood(self, requests):
-        import openai
-        res = []
-
         for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)):
             inps = []
             ctxlens = []
@@ -78,5 +76,23 @@ class GPT3LM(LM):
         return res
 
     def greedy_until(self, requests):
-        # TODO: implement
-        pass
\ No newline at end of file
+        import openai
+        res = []
+
+        for context, until in tqdm(requests):
+            context_enc = self.tokenizer.encode(context)
+
+            inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
+            ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS))
+
+            response = openai.Completion.create(
+                engine=self.engine,
+                prompt=[inp],
+                max_tokens=self.MAX_GEN_TOKS,
+                temperature=0.,
+                logprobs=10,
+            )
+
+            res.append(response.choices[0]['text'])
+
+        return res
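
For context, below is a minimal standalone sketch of the same greedy generation loop added in this commit, assuming the legacy openai Python client (openai.Completion) and a GPT-2 tokenizer from transformers; the names greedy_until_sketch and tok are illustrative and not part of the commit. One point worth flagging: the committed code unpacks until but does not forward it to the API, so the sketch passes it through the Completion endpoint's stop parameter as one possible way to honor the stop sequences.

# Minimal sketch, not the commit's code: assumes the legacy (pre-1.0) openai
# client and the GPT-2 tokenizer; greedy_until_sketch / tok are made-up names.
import openai
import transformers

MAX_LENGTH = 2048    # engine context window, prompt + completion
MAX_GEN_TOKS = 256   # token budget reserved for the completion

tok = transformers.GPT2Tokenizer.from_pretrained("gpt2")

def greedy_until_sketch(requests, engine="davinci"):
    res = []
    for context, until in requests:
        context_enc = tok.encode(context)
        # Keep only the most recent tokens so the prompt leaves room
        # for MAX_GEN_TOKS generated tokens inside the context window.
        inp = context_enc[-(MAX_LENGTH - MAX_GEN_TOKS):]
        response = openai.Completion.create(
            engine=engine,
            prompt=[inp],            # a list of token ids is a valid prompt
            max_tokens=MAX_GEN_TOKS,
            temperature=0.0,         # greedy decoding
            stop=until,              # assumption: forward the stop sequences
        )
        res.append(response.choices[0]["text"])
    return res

With stop=until the API ends generation at the first stop sequence, whereas the committed greedy_until returns the raw completion text and leaves any trimming on until to the caller.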