Commit 1ff4e07f authored by Leo Gao

Add gpt3 chunking

parent f3bf1c07
@@ -37,7 +37,7 @@ def oa_completion(**kwargs):
 class GPT3LM(LM):
     MAX_LENGTH = 2048
-    REQ_CHUNK_SIZE = 64
+    REQ_CHUNK_SIZE = 20
     MAX_GEN_TOKS = 256

     def __init__(self, engine, truncate=False):
@@ -101,28 +101,46 @@ class GPT3LM(LM):
         return res

     def greedy_until(self, requests):
+        if not requests: return []
         import openai
         res = []

-        for context, until in tqdm(requests):
-            context_enc = self.tokenizer.encode(context)
-            inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
-            ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS))
+        def sameuntil_chunks(xs, size):
+            ret = []
+            lastuntil = xs[0][1]
+            for x in xs:
+                if len(ret) >= size or x[1] != lastuntil:
+                    yield ret, lastuntil
+                    ret = []
+                    lastuntil = x[1]
+                ret.append(x)
+
+            if ret: yield ret, lastuntil
+
+        # todo: more intelligent batching for heterogeneous `until`
+        for chunk, until in tqdm(list(sameuntil_chunks(requests, self.REQ_CHUNK_SIZE))):
+            inps = []
+            for context, _ in chunk:
+                context_enc = self.tokenizer.encode(context)
+                inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
+                inps.append(inp)

             response = oa_completion(
                 engine=self.engine,
-                prompt=[inp],
+                prompt=inps,
                 max_tokens=self.MAX_GEN_TOKS,
                 temperature=0.,
                 logprobs=10,
                 stop=until
             )
-            s = response.choices[0]['text']
-
-            for term in until:
-                s = s.split(term)[0]
-
-            res.append(s)
+            for resp in response.choices:
+                s = resp['text']
+
+                for term in until:
+                    s = s.split(term)[0]
+
+                res.append(s)

         return res
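For reference, a minimal standalone sketch of the chunking helper introduced above: consecutive requests that share the same `until` stop sequence are grouped into chunks of at most REQ_CHUNK_SIZE, so each batched completion call can use a single stop value. The sample requests and the size argument are illustrative assumptions, not part of the commit.

def sameuntil_chunks(xs, size):
    # xs is a list of (context, until) pairs; yield (chunk, until) groups where
    # every request in a chunk shares the same `until` and chunks hold <= size items
    ret = []
    lastuntil = xs[0][1]
    for x in xs:
        if len(ret) >= size or x[1] != lastuntil:
            yield ret, lastuntil
            ret = []
            lastuntil = x[1]
        ret.append(x)
    if ret:
        yield ret, lastuntil

# hypothetical requests: two share the "\n" stop sequence, one uses "."
requests = [
    ("Question: 2+2=\nAnswer:", ["\n"]),
    ("Question: 3+3=\nAnswer:", ["\n"]),
    ("Translate to French: hello ->", ["."]),
]

for chunk, until in sameuntil_chunks(requests, size=2):
    print(len(chunk), until)   # -> 2 ['\n'] then 1 ['.']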
@@ -8,4 +8,5 @@ transformers>=4.1
 sqlitedict==1.6.0
 pytablewriter==0.58.0
 sacrebleu==1.5.0
-pycountry==20.7.3
\ No newline at end of file
+pycountry==20.7.3
+numexpr==2.7.2
\ No newline at end of file