Commit eb8456d0 authored by Leo Gao

Merge branch 'batching'

parents 7f7673ec 5f42f976
 import transformers
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from lm_eval.base import LM
 from lm_eval import utils
@@ -9,7 +10,7 @@ from tqdm import tqdm
 class GPT2LM(LM):
     MAX_GEN_TOKS = 256
 
-    def __init__(self, device=None, pretrained='gpt2'):
+    def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
         super().__init__()
         if device:
             self.device = torch.device(device)
@@ -29,10 +30,21 @@ class GPT2LM(LM):
         assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
 
+        # multithreading and batching
+        gpus = torch.cuda.device_count()
+        batch_size_per_gpu = batch_size  # TODO: adaptive batch size
+        self.batch_size = batch_size_per_gpu * gpus
+
+        # TODO: fix multi-gpu
+        # if gpus > 1:
+        #     self.gpt2 = nn.DataParallel(self.gpt2)
+
     @classmethod
-    def create_from_arg_string(cls, arg_string):
+    def create_from_arg_string(cls, arg_string, **kwargs):
         args = utils.simple_parse_args_string(arg_string)
-        return cls(device=args.get("device", None), pretrained=args.get("pretrained", "gpt2"))
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        return cls(pretrained=args.get("pretrained", "gpt2"), **kwargs)
 
     def loglikelihood(self, requests):
         new_reqs = []
@@ -53,36 +65,64 @@ class GPT2LM(LM):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
         with torch.no_grad():
-            # TODO: vectorize properly
-            # TODO: automatic batch size detection for vectorization
             def _collate(x):
+                # the negative sign on len(toks) sorts descending, which has a few advantages:
+                # - time estimates will always be over- rather than underestimates, which is more useful for planning
+                # - the first element of each batch is always the longest, so it sets the padded context length
+                #   for the whole batch; this simplifies the batching logic and, more importantly, makes automatic
+                #   adaptive batching much easier to implement
+                # - any OOMs will happen right away rather than near the end
                 toks = x[1] + x[2]
-                return (len(toks), tuple(toks))
+                return (-len(toks), tuple(toks))
 
+            # TODO: automatic (variable) batch size detection for vectorization
             reord = utils.Reorderer(requests, _collate)
-            for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
-                # when too long to fit in context, truncate from the left
-                inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
-                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
+            for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
+                inps = []
+                inplens = []
+                ctxlens = []
+                padding_length = None
+                for _, context_enc, continuation_enc in chunk:
+                    # when too long to fit in context, truncate from the left
+                    inp = torch.tensor((context_enc + continuation_enc)[-self.max_length:], dtype=torch.long).to(self.device)
+                    ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
+                    inplen, = inp.shape
 
-                cont_toks = inp[:, ctxlen:]  # [batch, seq]
-                logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
-                greedy_tokens = logits.argmax(dim=-1)
-                max_equal = (greedy_tokens == cont_toks).all()
-                last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
-                answer = (float(logits.sum()), bool(max_equal))
-                # partial caching
-                if cache_key is not None:
-                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
-                res.append(answer)
+                    # since _collate sorts by descending length, the longest sequence is always first
+                    padding_length = padding_length if padding_length is not None else inplen
+
+                    # pad to length
+                    inp = torch.cat([
+                        inp,  # [seq]
+                        torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
+                    ], dim=0)
+
+                    inps.append(inp.unsqueeze(0))
+                    inplens.append(inplen)
+                    ctxlens.append(ctxlen)
+
+                multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1)  # [batch, seq, vocab]
+
+                for (cache_key, _, _), logits, ctxlen, inp, inplen in zip(chunk, multi_logits, ctxlens, inps, inplens):
+                    logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0)  # [1, seq, vocab]
+                    greedy_tokens = logits.argmax(dim=-1)
+                    cont_toks = inp[:, ctxlen:inplen]  # [1, seq]
+                    max_equal = (greedy_tokens == cont_toks).all()
+                    last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
+                    answer = (float(logits.sum()), bool(max_equal))
+                    # partial caching
+                    if cache_key is not None:
+                        self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+                    res.append(answer)
 
         return reord.get_original(res)
...
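Two details of the new batched loop deserve a note. Because a language model's logits at position i predict token i + 1, the per-request slice `logits[ctxlen - 1:inplen - 1]` lines up exactly with the continuation tokens `inp[:, ctxlen:inplen]`, and bounding by `inplen` also discards positions that are pure zero-padding. The loop also leans on two small helpers from `lm_eval.utils`; below is a minimal sketch of their assumed behavior, inferred from how they are called here rather than from the real implementations:

```python
# Minimal sketch of the lm_eval.utils helpers assumed by the diff above.
# Inferred from call sites; see lm_eval/utils.py for the actual code.

def chunks(iterable, n):
    # Yield successive lists of up to n items from any iterable.
    batch = []
    for x in iterable:
        batch.append(x)
        if len(batch) == n:
            yield batch
            batch = []
    if batch:
        yield batch


class Reorderer:
    def __init__(self, arr, fn):
        # Remember each item's original index, then sort by the collate key
        # (here: descending length, so the longest request comes first).
        self.arr = sorted(enumerate(arr), key=lambda pair: fn(pair[1]))

    def get_reordered(self):
        return [x for _, x in self.arr]

    def get_original(self, res):
        # Scatter results computed in sorted order back to request order.
        out = [None] * len(res)
        for (idx, _), r in zip(self.arr, res):
            out[idx] = r
        return out
```

Sorting longest-first is what lets the first element of each chunk fix `padding_length` for the whole batch in a single pass.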
@@ -63,9 +63,10 @@ class GPT3LM(LM):
         openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
 
     @classmethod
-    def create_from_arg_string(cls, arg_string):
+    def create_from_arg_string(cls, arg_string, **kwargs):
         args = utils.simple_parse_args_string(arg_string)
-        return cls(engine=args.get("engine", "davinci"))
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        return cls(engine=args.get("engine", "davinci"), **kwargs)
 
     def loglikelihood(self, requests):
         new_reqs = []
@@ -91,7 +92,7 @@ class GPT3LM(LM):
             # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
             # we care about, and so we need some kind of backup for when it isn't
             toks = x[1] + x[2]
-            return (len(toks), tuple(toks))
+            return (-len(toks), tuple(toks))
 
         reord = utils.Reorderer(requests, _collate)
...
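Both model classes now share the same construction pattern: parse the `model_args` string into constructor keywords, then layer on whichever CLI flags were actually set. A rough sketch of the assumed parsing helper (the real one lives in `lm_eval/utils.py`) and the `None`-filtering it combines with:

```python
# Assumed behavior of utils.simple_parse_args_string (sketch, not the real code).
def simple_parse_args_string(arg_string):
    # "engine=davinci,key=value" -> {"engine": "davinci", "key": "value"}
    if not arg_string:
        return {}
    return dict(pair.split("=") for pair in arg_string.split(","))

# Dropping None-valued kwargs means an unset CLI flag never overrides a default:
kwargs = {"batch_size": None, "device": "cuda"}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
assert kwargs == {"device": "cuda"}
```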
@@ -94,8 +94,8 @@ TASK_REGISTRY = {
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
-    "cbt-cn": cbt.CBTCN,
-    "cbt-ne": cbt.CBTNE,
+    # "cbt-cn": cbt.CBTCN,  # disabled pending context length fix
+    # "cbt-ne": cbt.CBTNE,  # disabled pending context length fix
     "piqa": piqa.PiQA,
@@ -107,7 +107,7 @@ TASK_REGISTRY = {
     "qa4mre_2012": qa4mre.QA4MRE_2012,
     "qa4mre_2013": qa4mre.QA4MRE_2013,
-    "triviaqa": triviaqa.TriviaQA,
+    # "triviaqa": triviaqa.TriviaQA,  # disabled pending memory fix
     "arc_easy": arc.ARCEasy,
     "arc_challenge": arc.ARCChallenge,
     # "quac": quac.QuAC,  # not implemented yet
...
@@ -19,7 +19,7 @@ class Math(Task):
         sh(f"""
             mkdir -p {self.DATASET_PATH}
             wget https://people.eecs.berkeley.edu/~hendrycks/MATH.tar.gz -P data/
-            tar -xvf {self.DATASET_PATH}.tar.gz -C data/
+            tar -xf {self.DATASET_PATH}.tar.gz -C data/
             rm {self.DATASET_PATH}.tar.gz
             """)
...
@@ -23,7 +23,7 @@ class WordUnscrambleTask(Task):
     def download(self):
         if not self.BASE_PATH.exists():
-            Path.mkdir(self.BASE_PATH)
+            Path.mkdir(self.BASE_PATH, parents=True)
         file = self.BASE_PATH / self.FILENAME
         if not file.exists():
             rawfile = file.parent / (file.name + ".gz")
...
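The `parents=True` fix matters whenever `BASE_PATH` is nested: a bare `mkdir` raises `FileNotFoundError` if an intermediate directory is missing. A quick illustration (the concrete path is hypothetical):

```python
from pathlib import Path

base = Path("data/unscramble")  # hypothetical nested path
# base.mkdir()                  # FileNotFoundError if "data/" does not exist yet
base.mkdir(parents=True, exist_ok=True)  # creates "data/" and "data/unscramble"
```

Calling `Path.mkdir(self.BASE_PATH, parents=True)` as in the diff is equivalent to `self.BASE_PATH.mkdir(parents=True)`.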
@@ -15,6 +15,8 @@ def parse_args():
     parser.add_argument('--tasks', default="all_tasks")
     parser.add_argument('--provide_description', action="store_true")
     parser.add_argument('--num_fewshot', type=int, default=0)
+    parser.add_argument('--batch_size', type=int, default=None)
+    parser.add_argument('--device', type=int, default=None)
     parser.add_argument('--seed', type=int, default=1234)
     parser.add_argument('--output_path', default=None)
     parser.add_argument('--limit', type=int, default=None)
@@ -27,7 +29,7 @@ def main():
     random.seed(args.seed)
     np.random.seed(args.seed)
 
-    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
+    lm = models.get_model(args.model).create_from_arg_string(args.model_args, batch_size=args.batch_size, device=args.device)
 
     if args.limit:
         print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
...
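Since both new flags default to `None` and `create_from_arg_string` filters out `None`-valued kwargs, existing invocations are unchanged; a hypothetical run exercising the new flags might look like `python main.py --model gpt2 --model_args pretrained=gpt2-medium --batch_size 16 --device 0`.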
@@ -4,9 +4,12 @@ import lm_eval.models as models
 def test_gpt2():
     gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
-    (ll_dog, ig_dog), (ll_cat, ig_cat), *vals = gpt2.loglikelihood([
+    (ll_dog, ig_dog), (ll_cat, ig_cat), (_, ll_max_0), (_, ll_max_1), (_, ll_max_2), *vals = gpt2.loglikelihood([
         ('The quick brown fox jumps over the lazy', ' dog'),
         ('The quick brown fox jumps over the lazy', ' cat'),
+        ('The quick brown fox jumps over the lazy', ', lazy dog'),
+        ('The quick brown fox jumps over the lazy', ', lazy fox'),
+        ('The quick brown fox jumps over the lazy', ', lazy fox and they both fall to the ground'),
         ("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""),
         ("""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""", """ (with threshold activation); see § Terminology"""),
@@ -22,6 +25,10 @@ def test_gpt2():
     assert ll_dog > ll_cat
     assert not ig_cat
 
+    assert not ll_max_0
+    assert ll_max_1
+    assert ll_max_2
+
     # test empty context
     gpt2.loglikelihood([('', 'test')])
@@ -34,4 +41,4 @@ def test_gpt2():
     targets = [-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281]
 
     for (pred, _), tgt in zip(vals, targets):
-        assert pred == pytest.approx(tgt)
+        assert pred == pytest.approx(tgt, abs=1e-3)
 \ No newline at end of file
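The relaxed comparison acknowledges that padded, batched execution is not bit-identical to the old one-request-at-a-time path: `pytest.approx(tgt, abs=1e-3)` accepts any value within an absolute ±1e-3 of the recorded target, rather than pytest's default relative tolerance. For example:

```python
import pytest

# Differs from the target by about 1.6e-4, so it passes under abs=1e-3.
assert -61.6052 == pytest.approx(-61.60536193847656, abs=1e-3)
```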