Commit be3a6a2d authored by Leo Gao

Initial implementation of gpt2 batching

parent 8846bec0
 import transformers
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from lm_eval.base import LM
 from lm_eval import utils
@@ -29,6 +30,15 @@ class GPT2LM(LM):
         assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
 
+        # multithreading and batching
+        gpus = torch.cuda.device_count()
+        batch_size_per_gpu = 2  # todo: adaptive batch size
+        self.batch_size = batch_size_per_gpu * gpus
+
+        if gpus > 1:
+            self.gpt2 = nn.DataParallel(self.gpt2)
+
     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
@@ -53,36 +63,65 @@ class GPT2LM(LM):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
         with torch.no_grad():
-            # TODO: vectorize properly
-            # TODO: automatic batch size detection for vectorization
             def _collate(x):
+                # the negative sign on len(toks) sorts descending - this has a few advantages:
+                # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
+                # - when going through the list, the first element of each batch is always the longest, so it determines the batch's padded context length;
+                #   this simplifies the batching logic and, more importantly, makes automatic adaptive batch sizes much easier to implement
+                # - any OOMs will happen right away rather than near the end
                 toks = x[1] + x[2]
-                return (len(toks), tuple(toks))
+                return (-len(toks), tuple(toks))
 
+            # TODO: automatic (variable) batch size detection for vectorization
             reord = utils.Reorderer(requests, _collate)
-            for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
-                # when too long to fit in context, truncate from the left
-                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
-                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
+            for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
+                inps = []
+                ctxlens = []
+                inplens = []
+
+                padding_length = None
+                for _, context_enc, continuation_enc in chunk:
+                    # when too long to fit in context, truncate from the left
+                    inp = torch.tensor((context_enc + continuation_enc)[-self.max_length:], dtype=torch.long).to(self.device)
+                    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
+
+                    inplen, = inp.shape
+
+                    # since in _collate we make sure length is descending, the longest is always the first one.
+                    padding_length = padding_length if padding_length is not None else inplen
+
+                    # pad to length (zeros are created on the same device as inp so torch.cat works on GPU)
+                    inp = torch.cat([
+                        inp,  # [seq]
+                        torch.zeros(padding_length - inplen, dtype=torch.long, device=inp.device)  # [padding_length - seq]
+                    ], dim=0)
+
+                    inps.append(inp)
+                    ctxlens.append(ctxlen)
+                    inplens.append(inplen)
+
+                multi_logits = F.log_softmax(self.gpt2(torch.stack(inps, dim=0))[0][:, :, :50257], dim=-1)  # [batch, seq, vocab]
+
-                cont_toks = inp[:, ctxlen:]  # [batch, seq]
-                logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+                for (cache_key, _, _), logits, inp, ctxlen, inplen in zip(chunk, multi_logits, inps, ctxlens, inplens):
+                    logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0)  # [1, seq, vocab]
                     greedy_tokens = logits.argmax(dim=-1)
-                max_equal = (greedy_tokens == cont_toks).all()
+
+                    cont_toks = inp[ctxlen:inplen].unsqueeze(0)  # [1, seq]
+                    max_equal = (greedy_tokens == cont_toks).all()
 
                     last_token_slice = logits[:, -1, :].squeeze(0).tolist()
 
                     logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
 
                     answer = (float(logits.sum()), bool(max_equal))
 
                     # partial caching
                     if cache_key is not None:
                         self.cache_hook.add_partial("loglikelihood", cache_key, answer)
 
                     res.append(answer)
 
         return reord.get_original(res)
...
@@ -91,7 +91,7 @@ class GPT3LM(LM):
             # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
             # we care about and so we need some kind of backup for when it isn't
             toks = x[1] + x[2]
-            return (len(toks), tuple(toks))
+            return (-len(toks), tuple(toks))
 
         reord = utils.Reorderer(requests, _collate)
...
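
As a sanity check on the batching scheme above, here is a minimal standalone sketch of the same idea outside the harness: sort requests by descending total length, walk them in fixed-size chunks, pad each chunk to its first (and therefore longest) member, run one forward pass per chunk, and slice each request's continuation log-probabilities back out. The toy token ids, the batch size of 2, and the random log-softmax standing in for GPT-2 are assumptions made purely for illustration; only the padding and index bookkeeping mirror the diff.

import torch

# Toy (context, continuation) token-id pairs standing in for real requests;
# the ids, lengths, and batch size are made up purely for illustration.
requests = [
    ([1, 2], [3]),            # total length 3
    ([4, 5, 6, 7], [8, 9]),   # total length 6
    ([10], [11]),             # total length 2
    ([12, 13, 14], [15]),     # total length 4
]

# Sort descending by total length, mirroring _collate's (-len(toks), ...) key:
# the first element of every chunk is then the longest, so its length can serve
# as the chunk's padded length, and any OOM shows up on the very first chunk.
order = sorted(range(len(requests)), key=lambda i: -len(requests[i][0] + requests[i][1]))

batch_size = 2
for start in range(0, len(order), batch_size):
    chunk = [requests[i] for i in order[start:start + batch_size]]

    inps, ctxlens, inplens = [], [], []
    padding_length = None
    for context_enc, continuation_enc in chunk:
        inp = torch.tensor(context_enc + continuation_enc, dtype=torch.long)
        inplen = inp.shape[0]
        # the first (longest) sequence fixes the padding length for the whole chunk
        padding_length = padding_length if padding_length is not None else inplen
        inp = torch.cat([inp, torch.zeros(padding_length - inplen, dtype=torch.long)])
        inps.append(inp)
        ctxlens.append(len(context_enc))
        inplens.append(inplen)

    batch = torch.stack(inps, dim=0)  # [batch, padded_seq], a single forward pass per chunk

    # Stand-in for the model: random log-probabilities over a tiny 16-token vocabulary.
    multi_logits = torch.log_softmax(torch.randn(batch.shape[0], batch.shape[1], 16), dim=-1)

    # Recover each request's continuation log-likelihood from the batched output.
    # Position i of the logits predicts token i + 1, hence the (ctxlen - 1, inplen - 1) window.
    for logits, inp, ctxlen, inplen in zip(multi_logits, inps, ctxlens, inplens):
        cont_logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0)   # [1, cont_len, vocab]
        cont_toks = inp[ctxlen:inplen].unsqueeze(0)                # [1, cont_len]
        logprob = torch.gather(cont_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1).sum()
        print(float(logprob))

Because the padded length always comes from the first element of a chunk, sorting descending lets padding_length be chosen without a second pass over the chunk, which is exactly the property the _collate comment in the diff is pointing at.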