Commit be3a6a2d authored by Leo Gao

Initial implementation of gpt2 batching

parent 8846bec0
import transformers
import torch
+import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
@@ -29,6 +30,15 @@ class GPT2LM(LM):
        assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]

+        # multithreading and batching
+        gpus = torch.cuda.device_count()
+        batch_size_per_gpu = 2  # todo: adaptive batch size
+        self.batch_size = batch_size_per_gpu * gpus
+
+        if gpus > 1:
+            self.gpt2 = nn.DataParallel(self.gpt2)

    @classmethod
    def create_from_arg_string(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
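The hunk above computes an effective batch size of batch_size_per_gpu * gpus and wraps the model in nn.DataParallel when more than one GPU is visible, so a single forward call consumes the whole batch. Below is a minimal sketch (not part of this commit) of how those pieces interact; TinyLM is a stand-in for self.gpt2, and the max(gpus, 1) guard is an added assumption so the sketch also runs on a CPU-only machine (the committed code assumes at least one GPU).

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # stand-in for self.gpt2: maps [batch, seq] token ids to [batch, seq, vocab] logits
    def __init__(self, vocab=50257, hidden=16):
        super().__init__()
        self.emb = nn.Embedding(vocab, hidden)
        self.out = nn.Linear(hidden, vocab)

    def forward(self, tokens):
        return self.out(self.emb(tokens))

model = TinyLM()
gpus = torch.cuda.device_count()
batch_size_per_gpu = 2
batch_size = batch_size_per_gpu * max(gpus, 1)  # max(gpus, 1): assumed CPU fallback, not in the commit

if gpus > 1:
    # DataParallel replicates the module and splits dim 0 of the input across the GPUs
    model = nn.DataParallel(model.cuda())
elif gpus == 1:
    model = model.cuda()

tokens = torch.randint(0, 50257, (batch_size, 8))
if gpus >= 1:
    tokens = tokens.cuda()
logits = model(tokens)  # [batch_size, 8, 50257]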
@@ -53,23 +63,52 @@ class GPT2LM(LM):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
        with torch.no_grad():
-            # TODO: vectorize properly
-            # TODO: automatic batch size detection for vectorization
            def _collate(x):
+                # the negative sign on len(toks) sorts descending - this has a few advantages:
+                # - time estimates will always be overestimates rather than underestimates, which is more useful for planning
+                # - when going through the list, the first element of each batch is always the longest, so it determines
+                #   the padded length for that batch. this simplifies the batching logic and, more importantly, makes
+                #   automatic adaptive batching much easier to implement
+                # - any OOMs will happen right away rather than near the end
                toks = x[1] + x[2]
-                return (len(toks), tuple(toks))
+                return (-len(toks), tuple(toks))

+            # TODO: automatic (variable) batch size detection for vectorization
            reord = utils.Reorderer(requests, _collate)
-            for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
+            for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
+                inps = []
+                ctxlens = []
+                inplens = []
+                padding_length = None
+                for _, context_enc, continuation_enc in chunk:
                    # when too long to fit in context, truncate from the left
-                    inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
+                    inp = torch.tensor((context_enc + continuation_enc)[-self.max_length:], dtype=torch.long).to(self.device)
                    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
+                    inplen, = inp.shape
-                    cont_toks = inp[:, ctxlen:]  # [batch, seq]
-                    logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+                    # since in _collate we make sure length is descending, the longest is always the first one
+                    padding_length = padding_length if padding_length is not None else inplen
+                    # pad to the batch's padding_length
+                    inp = torch.cat([
+                        inp,  # [seq]
+                        torch.zeros(padding_length - inplen, dtype=torch.long).to(self.device)  # [padding_length - seq], on the same device as inp
+                    ], dim=0)
+                    inps.append(inp)
+                    ctxlens.append(ctxlen)
+                    inplens.append(inplen)
+                multi_logits = F.log_softmax(self.gpt2(torch.stack(inps, dim=0))[0][:, :, :50257], dim=-1)  # [batch, seq, vocab]
+                for (cache_key, _, continuation_enc), logits, ctxlen, inplen in zip(chunk, multi_logits, ctxlens, inplens):
+                    # slice out just the positions that predict the continuation tokens; keep a batch dim of 1
+                    logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0)  # [1, seq, vocab]
                    greedy_tokens = logits.argmax(dim=-1)
+                    cont_toks = torch.tensor(continuation_enc, dtype=torch.long, device=self.device).unsqueeze(0)  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()
                    last_token_slice = logits[:, -1, :].squeeze(0).tolist()
......
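Putting the new loop together outside the diff context: the sketch below (not part of the commit) shows the same batching pattern end to end — pad each request up to the first (longest) element of the batch, run one forward pass, then score each continuation — assuming the requests are already sorted longest-first. The score_batch name, and the assumption that model returns logits directly (rather than a tuple, hence no [0]), are illustrative only.

import torch
import torch.nn.functional as F

def score_batch(model, batch, max_length=1024, device='cpu'):
    # batch: list of (context_tokens, continuation_tokens), pre-sorted longest-first
    inps, ctxlens, inplens = [], [], []
    padding_length = None
    for context_enc, continuation_enc in batch:
        inp = torch.tensor((context_enc + continuation_enc)[-max_length:], dtype=torch.long, device=device)
        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - max_length)
        inplen, = inp.shape
        # the first element is the longest, so it defines the padded length for the whole batch
        padding_length = padding_length if padding_length is not None else inplen
        inp = torch.cat([inp, torch.zeros(padding_length - inplen, dtype=torch.long, device=device)])
        inps.append(inp); ctxlens.append(ctxlen); inplens.append(inplen)

    # one forward pass for the whole padded batch
    logits = F.log_softmax(model(torch.stack(inps))[:, :, :50257], dim=-1)  # [batch, seq, vocab]

    results = []
    for (_, continuation_enc), lg, ctxlen, inplen in zip(batch, logits, ctxlens, inplens):
        lg = lg[ctxlen - 1:inplen - 1]                      # positions that predict the continuation
        cont = torch.tensor(continuation_enc, device=device)
        ll = lg.gather(1, cont.unsqueeze(1)).sum().item()   # total log-likelihood of the continuation
        greedy = bool((lg.argmax(dim=-1) == cont).all())    # whether greedy decoding would reproduce it
        results.append((ll, greedy))
    return results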
@@ -91,7 +91,7 @@ class GPT3LM(LM):
            # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
-            return (len(toks), tuple(toks))
+            return (-len(toks), tuple(toks))
        reord = utils.Reorderer(requests, _collate)
......
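Both GPT2LM and GPT3LM now sort with the key (-len(toks), tuple(toks)) before iterating. The snippet below is a hypothetical stand-in (not the real lm_eval.utils.Reorderer or utils.chunks) that illustrates the effect: batches come out longest-first, so the first element of every chunk fixes the padded length, OOMs surface on the first batch, and results can still be written back in the original request order.

def chunks(xs, n):
    # yield consecutive slices of size n (the last one may be shorter)
    for i in range(0, len(xs), n):
        yield xs[i:i + n]

def _collate(x):
    toks = x[1] + x[2]
    return (-len(toks), tuple(toks))  # negative length => longest requests sort first

requests = [
    ("a", [1, 2], [3]),           # 3 tokens total
    ("b", [1], [2, 3, 4, 5]),     # 5 tokens total
    ("c", [1, 2, 3], [4, 5, 6]),  # 6 tokens total
]

# sort indices by the collate key (the reordering the diff delegates to utils.Reorderer)
order = sorted(range(len(requests)), key=lambda i: _collate(requests[i]))
reordered = [requests[i] for i in order]  # [c, b, a]

results = [None] * len(requests)
for chunk in chunks(list(enumerate(order)), 2):
    # within a chunk, the first element is the longest, so it would define the padded length
    for pos, orig_idx in chunk:
        ctx, cont = reordered[pos][1], reordered[pos][2]
        results[orig_idx] = len(ctx + cont)  # placeholder for the real per-request work

print(results)  # [3, 5, 6] - back in the original a, b, c order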