Commit 8e8d7c6d authored by Leo Gao's avatar Leo Gao
Browse files

Add Reorderer and implement in gpt2 and gpt3

parent 1b4242c1
...@@ -35,7 +35,13 @@ class GPT2LM(LM): ...@@ -35,7 +35,13 @@ class GPT2LM(LM):
with torch.no_grad(): with torch.no_grad():
# TODO: vectorize properly # TODO: vectorize properly
# TODO: automatic batch size detection for vectorization # TODO: automatic batch size detection for vectorization
for context, continuation in tqdm(requests):
def _collate(x):
toks = self.tokenizer.encode(x[0] + x[1])[:-1]
return (len(toks), self.tokenizer.decode(toks))
reord = utils.Reorderer(requests, _collate)
for context, continuation in tqdm(reord.get_reordered()):
# when too long to fit in context, truncate from the left # when too long to fit in context, truncate from the left
if context == "": if context == "":
...@@ -59,14 +65,20 @@ class GPT2LM(LM): ...@@ -59,14 +65,20 @@ class GPT2LM(LM):
res.append((float(logits.sum()), bool(max_equal))) res.append((float(logits.sum()), bool(max_equal)))
return res return reord.get_original(res)
def greedy_until(self, requests): def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are # TODO: implement fully general `until` that handles untils that are
# multiple tokens or that span multiple tokens correctly # multiple tokens or that span multiple tokens correctly
res = [] res = []
for context, until in tqdm(requests): def _collate(x):
toks = self.tokenizer.encode(x[0])
return (len(toks), x[0])
reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str): until = [until] if isinstance(until, str): until = [until]
context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device) context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)
...@@ -87,4 +99,4 @@ class GPT2LM(LM): ...@@ -87,4 +99,4 @@ class GPT2LM(LM):
res.append(s) res.append(s)
return res return reord.get_original(res)
...@@ -70,7 +70,13 @@ class GPT3LM(LM): ...@@ -70,7 +70,13 @@ class GPT3LM(LM):
import openai import openai
res = [] res = []
for chunk in tqdm(list(utils.chunks(requests, self.REQ_CHUNK_SIZE))): def _collate(x):
toks = self.tokenizer.encode(x[0] + x[1])[:-1]
return (len(toks), self.tokenizer.decode(toks))
reord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = [] inps = []
ctxlens = [] ctxlens = []
for context, continuation in chunk: for context, continuation in chunk:
...@@ -98,13 +104,19 @@ class GPT3LM(LM): ...@@ -98,13 +104,19 @@ class GPT3LM(LM):
for resp, ctxlen in zip(response.choices, ctxlens): for resp, ctxlen in zip(response.choices, ctxlens):
res.append(get_result(resp, ctxlen)) res.append(get_result(resp, ctxlen))
return res return reord.get_original(res)
def greedy_until(self, requests): def greedy_until(self, requests):
if not requests: return [] if not requests: return []
import openai import openai
res = [] res = []
def _collate(x):
toks = self.tokenizer.encode(x[0])
return (len(toks), x[0])
reord = utils.Reorderer(requests, _collate)
def sameuntil_chunks(xs, size): def sameuntil_chunks(xs, size):
ret = [] ret = []
lastuntil = xs[0][1] lastuntil = xs[0][1]
...@@ -118,7 +130,7 @@ class GPT3LM(LM): ...@@ -118,7 +130,7 @@ class GPT3LM(LM):
if ret: yield ret, lastuntil if ret: yield ret, lastuntil
# todo: more intelligent batching for heterogenous `until` # todo: more intelligent batching for heterogenous `until`
for chunk, until in tqdm(list(sameuntil_chunks(requests, self.REQ_CHUNK_SIZE))): for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = [] inps = []
for context, _ in chunk: for context, _ in chunk:
context_enc = self.tokenizer.encode(context) context_enc = self.tokenizer.encode(context)
...@@ -142,5 +154,5 @@ class GPT3LM(LM): ...@@ -142,5 +154,5 @@ class GPT3LM(LM):
res.append(s) res.append(s)
return res return reord.get_original(res)
import os import os
import re import re
import collections
class ExitCodeError(Exception): class ExitCodeError(Exception):
...@@ -42,6 +43,14 @@ def chunks(iter, n): ...@@ -42,6 +43,14 @@ def chunks(iter, n):
if arr: yield arr if arr: yield arr
def group(arr, fn):
    """Partition *arr* into lists of items sharing the same key ``fn(item)``.

    Buckets are returned in first-seen key order (dicts preserve
    insertion order), and items keep their relative order within a bucket.
    """
    grouped = {}
    for item in arr:
        grouped.setdefault(fn(item), []).append(item)
    return list(grouped.values())
def general_detokenize(string): def general_detokenize(string):
string = string.replace(" n't", "n't") string = string.replace(" n't", "n't")
string = string.replace(" )", ")") string = string.replace(" )", ")")
...@@ -49,4 +58,34 @@ def general_detokenize(string): ...@@ -49,4 +58,34 @@ def general_detokenize(string):
string = string.replace("\" ", "\"") string = string.replace("\" ", "\"")
string = string.replace(" \"", "\"") string = string.replace(" \"", "\"")
string = re.sub(r" (['.,])", r"\1", string) string = re.sub(r" (['.,])", r"\1", string)
return string return string
\ No newline at end of file
class Reorderer:
    """Reorder a list by a sort key, then restore results to original order.

    Used to sort LM requests (e.g. by token length) so similar-sized items
    batch together, while ``get_original`` maps the computed results back to
    the caller's original ordering. Items with identical keys are deduplicated
    into one representative, and its result is fanned back out to every
    original position.
    """

    def __init__(self, arr, fn):
        """
        :param arr: list of items to reorder
        :param fn: key function; items with equal ``fn(item)`` share one
            representative. The key is computed exactly once per item
            (the original evaluated ``fn`` again for every sort comparison
            group).
        """
        self.size = len(arr)
        # Bucket (original_index, item) pairs by key, preserving first-seen
        # key order; this inlines the grouping so fn runs once per element.
        buckets = collections.defaultdict(list)
        for ind, ob in enumerate(arr):
            buckets[fn(ob)].append((ind, ob))
        # Each entry: ([original indices sharing this key], representative, key)
        self.arr = sorted(
            ((([ind for ind, _ in pairs]), pairs[0][1]) for _, pairs in buckets.items()),
            key=lambda entry: fn(entry[1]),
        ) if False else [
            ([ind for ind, _ in pairs], pairs[0][1]) for _, pairs in buckets.items()
        ]
        # Sort by the precomputed key rather than re-invoking fn.
        keyed = sorted(
            zip(buckets.keys(), self.arr), key=lambda kv: kv[0]
        )
        self.arr = [entry for _, entry in keyed]

    def get_reordered(self):
        """Return the deduplicated representatives in sorted-key order."""
        return [rep for _, rep in self.arr]

    def get_original(self, newarr):
        """Scatter *newarr* (aligned with ``get_reordered()``) back to the
        original ordering, duplicating each result across all positions
        that shared its key.

        :raises AssertionError: if some original position was never covered.
        """
        res = [None] * self.size
        cov = [False] * self.size
        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True
        assert all(cov)  # every original slot must receive a result
        return res
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment