Commit a586a5c4 authored by Leo Gao

More refactoring of model code

parent ab1fdc54
@@ -45,7 +45,7 @@ class GPT2LM(LM):
            continuation_enc = self.tokenizer.encode(continuation)
-           new_reqs.append((context_enc, continuation_enc))
+           new_reqs.append(((context, continuation), context_enc, continuation_enc))
        return self._loglikelihood_tokens(new_reqs)
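With this change each entry in new_reqs carries the original (context, continuation) strings alongside the encoded token ids, so downstream code can use the raw strings as a cache key instead of decoding token ids back to text. A small illustration of the new request shape (the strings and token ids below are made up):

```python
# Illustration of the new request tuple shape; example strings and token ids are hypothetical.
context, continuation = "The capital of France is", " Paris"
context_enc = [464, 3139, 286, 4881, 318]   # made-up token ids for the context
continuation_enc = [6342]                   # made-up token id for the continuation

req = ((context, continuation), context_enc, continuation_enc)
cache_key, context_enc, continuation_enc = req  # how _loglikelihood_tokens unpacks it
```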
@@ -57,11 +57,11 @@ class GPT2LM(LM):
        # TODO: automatic batch size detection for vectorization
        def _collate(x):
-           toks = x[0] + x[1]
+           toks = x[1] + x[2]
            return (len(toks), tuple(toks))
        reord = utils.Reorderer(requests, _collate)
-       for context_enc, continuation_enc in tqdm(reord.get_reordered()):
+       for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
            # when too long to fit in context, truncate from the left
            inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
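The ctxlen arithmetic above records how many of the surviving tokens still count as context after left-truncation: when context plus continuation exceed max_length, the overflow is trimmed from the context side only. A quick worked sketch with hypothetical lengths:

```python
# Worked sketch of the left-truncation bookkeeping; token counts are made up for illustration.
max_length = 1024
context_enc = list(range(1000))        # 1000 context tokens
continuation_enc = list(range(50))     # 50 continuation tokens

# keep only the last max_length tokens of (context + continuation)
inp = (context_enc + continuation_enc)[-max_length:]

# how many of the kept tokens are still context:
# 1000 - max(0, 1000 + 50 - 1024) = 1000 - 26 = 974
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - max_length)

assert len(inp) == 1024 and ctxlen == 974
assert len(inp) - ctxlen == len(continuation_enc)  # the continuation itself is never truncated
```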
@@ -79,8 +79,8 @@ class GPT2LM(LM):
            answer = (float(logits.sum()), bool(max_equal))
            # partial caching
-           # TODO: make sure that decode reverses correctly
-           self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)
+           if cache_key is not None:
+               self.cache_hook.add_partial("loglikelihood", cache_key, answer)
            res.append(answer)
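The old caching path rebuilt the cache key by decoding token ids back to strings, which the removed TODO flagged as potentially not round-tripping through the tokenizer. The new path reuses the cache_key (the original strings) threaded through from loglikelihood, and skips caching when no key is available. A minimal sketch of what such a cache hook might look like (the real cache_hook implementation is not part of this diff):

```python
# Minimal sketch of the partial-caching idea, assuming a hypothetical CacheHook class.
class CacheHook:
    def __init__(self):
        self.cache = {}

    def add_partial(self, attr, req, res):
        # key results by the original request strings (cache_key), not by
        # tokenizer.decode(...) of the token ids, which may not round-trip exactly
        self.cache[(attr, req)] = res

hook = CacheHook()
hook.add_partial("loglikelihood", ("The capital of France is", " Paris"), (-1.23, True))
```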
@@ -78,7 +78,7 @@ class GPT3LM(LM):
            continuation_enc = self.tokenizer.encode(continuation)
-           new_reqs.append((context_enc, continuation_enc))
+           new_reqs.append(((context, continuation), context_enc, continuation_enc))
        return self._loglikelihood_tokens(new_reqs)
@@ -90,15 +90,15 @@ class GPT3LM(LM):
            # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
-           toks = self.tokenizer.encode(x[0] + x[1])
-           return (len(toks), self.tokenizer.decode(toks))
+           toks = x[1] + x[2]
+           return (len(toks), tuple(toks))
        reord = utils.Reorderer(requests, _collate)
        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            ctxlens = []
-           for context_enc, continuation_enc in chunk:
+           for cache_key, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
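For the GPT3 backend, requests are grouped into batches of REQ_CHUNK_SIZE per API call via utils.chunks. A rough sketch of a fixed-size chunking helper in that spirit (the actual utils.chunks may be implemented differently):

```python
# Rough sketch of a fixed-size chunking helper; the real utils.chunks may differ.
def chunks(iterable, n):
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) == n:
            yield buf
            buf = []
    if buf:
        yield buf  # final, possibly smaller, chunk

print(list(chunks(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```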
@@ -113,14 +113,14 @@ class GPT3LM(LM):
                logprobs=10,
            )
-           for resp, ctxlen, (context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
+           for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)
                res.append(answer)
                # partial caching
-               # TODO: make sure that decode reverses correctly
-               self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)
+               if cache_key is not None:
+                   self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return reord.get_original(res)
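Both backends rely on utils.Reorderer: _collate sorts requests by total token length so similarly sized requests end up near each other, and get_original scatters the results back into the caller's original order. A hedged sketch of what such a Reorderer could look like (the real utils.Reorderer may be implemented differently):

```python
# Sketch of a Reorderer in the spirit of utils.Reorderer; not the actual implementation.
class Reorderer:
    def __init__(self, arr, sort_fn):
        self.size = len(arr)
        # remember each item's original index, then sort by the collation key
        self.order = sorted(range(len(arr)), key=lambda i: sort_fn(arr[i]))
        self.arr = [arr[i] for i in self.order]

    def get_reordered(self):
        return self.arr

    def get_original(self, res):
        # scatter results back to their original positions
        out = [None] * self.size
        for pos, r in zip(self.order, res):
            out[pos] = r
        return out
```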