Commit a586a5c4 authored by Leo Gao

More refactoring of model code

parent ab1fdc54
@@ -45,7 +45,7 @@ class GPT2LM(LM):
             continuation_enc = self.tokenizer.encode(continuation)
-            new_reqs.append((context_enc, continuation_enc))
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
         return self._loglikelihood_tokens(new_reqs)
@@ -57,11 +57,11 @@ class GPT2LM(LM):
         # TODO: automatic batch size detection for vectorization
         def _collate(x):
-            toks = x[0] + x[1]
+            toks = x[1] + x[2]
             return (len(toks), tuple(toks))
         reord = utils.Reorderer(requests, _collate)
-        for context_enc, continuation_enc in tqdm(reord.get_reordered()):
+        for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
             # when too long to fit in context, truncate from the left
             inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
             ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
@@ -79,8 +79,8 @@ class GPT2LM(LM):
             answer = (float(logits.sum()), bool(max_equal))
             # partial caching
-            # TODO: make sure that decode reverses correctly
-            self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)
+            if cache_key is not None:
+                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
             res.append(answer)
...
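Taken on its own, the GPT2LM change above does two things: each request tuple now carries the raw (context, continuation) string pair as a cache key alongside the token encodings, and the partial-caching call passes that key directly instead of rebuilding it with tokenizer.decode() (the concern behind the removed "make sure that decode reverses correctly" TODO). The sketch below only illustrates that shape; DictCacheHook and build_loglikelihood_requests are hypothetical names rather than code from this repo, and the only parts mirrored from the diff are the tuple layout and the add_partial call.

# Hypothetical sketch of the new request shape and cache-key usage.
# Assumes a tokenizer exposing .encode() and a cache hook exposing
# .add_partial(), which is the interface the diff relies on.

class DictCacheHook:
    """Stand-in cache hook: stores partial results keyed by the raw strings."""
    def __init__(self):
        self.cache = {}

    def add_partial(self, kind, cache_key, result):
        # cache_key is the original (context, continuation) pair, so no
        # decode() round-trip is needed to reconstruct the request text.
        self.cache[(kind, cache_key)] = result


def build_loglikelihood_requests(requests, tokenizer):
    # Mirrors the change in loglikelihood(): keep the raw strings as a key.
    new_reqs = []
    for context, continuation in requests:
        context_enc = tokenizer.encode(context)
        continuation_enc = tokenizer.encode(continuation)
        # old shape: (context_enc, continuation_enc)
        # new shape: ((context, continuation), context_enc, continuation_enc)
        new_reqs.append(((context, continuation), context_enc, continuation_enc))
    return new_reqs

The new "if cache_key is not None" guard presumably leaves room for callers that pass requests without a usable key (cache_key=None) and want to skip caching for them.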
@@ -78,7 +78,7 @@ class GPT3LM(LM):
             continuation_enc = self.tokenizer.encode(continuation)
-            new_reqs.append((context_enc, continuation_enc))
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
         return self._loglikelihood_tokens(new_reqs)
@@ -90,15 +90,15 @@ class GPT3LM(LM):
             # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
             # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
             # we care about and so we need some kind of backup for when it isn't
-            toks = self.tokenizer.encode(x[0] + x[1])
-            return (len(toks), self.tokenizer.decode(toks))
+            toks = x[1] + x[2]
+            return (len(toks), tuple(toks))
         reord = utils.Reorderer(requests, _collate)
         for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
             inps = []
             ctxlens = []
-            for context_enc, continuation_enc in chunk:
+            for cache_key, context_enc, continuation_enc in chunk:
                 inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
                 ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
@@ -113,14 +113,14 @@ class GPT3LM(LM):
                 logprobs=10,
             )
-            for resp, ctxlen, (context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
+            for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(response.choices, ctxlens, chunk):
                 answer = get_result(resp, ctxlen)
                 res.append(answer)
                 # partial caching
-                # TODO: make sure that decode reverses correctly
-                self.cache_hook.add_partial("loglikelihood", (self.tokenizer.decode(context_enc), self.tokenizer.decode(continuation_enc)), answer)
+                if cache_key is not None:
+                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
         return reord.get_original(res)
...
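For the GPT3LM path, _collate now builds its sort key from the already-encoded token ids (x[1] + x[2]) rather than re-encoding and re-decoding the concatenated strings, and the unpacking throughout the chunked loop is updated for the extra cache_key element. A rough stand-in for the supporting pieces is sketched below; this Reorderer is written for illustration only and the real utils.Reorderer may differ, while collate_key and truncate_left just restate logic visible in the diff.

# Illustrative stand-ins, assuming requests shaped as
# (cache_key, context_enc, continuation_enc).

class Reorderer:
    """Process items in sorted order, then map results back to input order."""
    def __init__(self, items, key):
        self.items = items
        self.order = sorted(range(len(items)), key=lambda i: key(items[i]))

    def get_reordered(self):
        return [self.items[i] for i in self.order]

    def get_original(self, results):
        out = [None] * len(results)
        for pos, i in enumerate(self.order):
            out[i] = results[pos]
        return out


def collate_key(req):
    # Sort by total token length (then token content) so requests of similar
    # size land in the same REQ_CHUNK_SIZE chunk sent to the API.
    toks = req[1] + req[2]
    return (len(toks), tuple(toks))


def truncate_left(context_enc, continuation_enc, max_length):
    # Keep only the last max_length tokens; ctxlen counts how many of the
    # kept tokens still belong to the context after left-truncation.
    inp = (context_enc + continuation_enc)[-max_length:]
    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - max_length)
    return inp, ctxlen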