Commit 1ff4e07f authored by Leo Gao
Browse files

Add gpt3 chunking

parent f3bf1c07
......@@ -37,7 +37,7 @@ def oa_completion(**kwargs):
class GPT3LM(LM):
MAX_LENGTH = 2048
REQ_CHUNK_SIZE = 64
REQ_CHUNK_SIZE = 20
MAX_GEN_TOKS = 256
def __init__(self, engine, truncate=False):
......@@ -101,28 +101,46 @@ class GPT3LM(LM):
return res
def greedy_until(self, requests):
    """Greedily generate a continuation for each (context, until) request.

    Args:
        requests: list of (context, until) pairs. `context` is the prompt
            string; `until` is a list of stop-sequence strings that end
            generation for that request.

    Returns:
        A list of generated strings, one per request and in the same order,
        each truncated at the first occurrence of any of its stop sequences.
    """
    if not requests:
        return []
    res = []

    def sameuntil_chunks(xs, size):
        # Group consecutive requests that share the same `until` list into
        # chunks of at most `size`, so one API call can pass a single `stop`
        # argument for the whole batch. Yields (chunk, until) pairs.
        ret = []
        lastuntil = xs[0][1]
        for x in xs:
            if len(ret) >= size or x[1] != lastuntil:
                yield ret, lastuntil
                ret = []
                lastuntil = x[1]
            ret.append(x)
        if ret:
            yield ret, lastuntil

    # todo: more intelligent batching for heterogeneous `until`
    for chunk, until in tqdm(list(sameuntil_chunks(requests, self.REQ_CHUNK_SIZE))):
        inps = []
        for context, _ in chunk:
            context_enc = self.tokenizer.encode(context)
            # Keep only the rightmost context tokens, leaving room for
            # MAX_GEN_TOKS generated tokens within the model's MAX_LENGTH.
            inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
            inps.append(inp)

        response = oa_completion(
            engine=self.engine,
            prompt=inps,
            max_tokens=self.MAX_GEN_TOKS,
            temperature=0.,
            logprobs=10,
            stop=until,
        )

        for resp in response.choices:
            # BUG FIX: the original read response.choices[0]['text'] inside
            # this loop, so every request in a chunk received the first
            # request's completion. Use the per-request choice instead.
            s = resp['text']
            # Truncate at each stop sequence (the API's `stop` may leave the
            # terminator out, but truncate defensively client-side as well).
            for term in until:
                s = s.split(term)[0]
            res.append(s)
    return res
......@@ -8,4 +8,5 @@ transformers>=4.1
sqlitedict==1.6.0
pytablewriter==0.58.0
sacrebleu==1.5.0
pycountry==20.7.3
\ No newline at end of file
pycountry==20.7.3
numexpr==2.7.2
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment