Commit fea936e6 authored by Leo Gao

Merge TokenizedLM and TorchLM into BaseLM

parent 7f24a08b
@@ -4,16 +4,17 @@ from typing import Iterable
import numpy as np
import re
from tqdm import tqdm
import torch
from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
from lm_eval import utils
from abc import abstractmethod
class LM(abc.ABC):
def __init__(self):
self.cache_hook = CacheHook(None)
@abc.abstractmethod
@abstractmethod
def loglikelihood(self, requests):
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
@@ -37,7 +38,7 @@ class LM(abc.ABC):
"""
pass
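For orientation, a downstream caller passes (context, continuation) pairs and gets back (log-probability, is-greedy) pairs; a hypothetical call (the strings, the lm handle, and the returned numbers are illustrative only):
# score two candidate continuations of the same context
requests = [("The capital of France is", " Paris"),
            ("The capital of France is", " London")]
lm.loglikelihood(requests)  # e.g. [(-0.7, True), (-9.3, False)]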
@abc.abstractmethod
@abstractmethod
def loglikelihood_rolling(self, requests):
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
@@ -80,7 +81,7 @@ class LM(abc.ABC):
pass
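A hypothetical call, for orientation (this sketch assumes each request wraps a single string; the text and number are illustrative only):
requests = [("Some long document whose perplexity we want.",)]
lm.loglikelihood_rolling(requests)  # e.g. [-42.3], one total log-likelihood per string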
# TODO: Add an optional max length
@abc.abstractmethod
@abstractmethod
def greedy_until(self, requests):
"""Generate greedily until a stopping sequence
@@ -108,15 +109,26 @@ class LM(abc.ABC):
self.cache_hook = cache_hook
class TokenizedLM(LM):
@abc.abstractmethod
class BaseLM(LM):
@abstractmethod
def tok_encode(self, string: str): pass
@abc.abstractmethod
@abstractmethod
def tok_decode(self, tokens: Iterable[int]): pass
@abc.abstractmethod
def _loglikelihood_tokens(self, requests, disable_tqdm=False): pass
@abstractmethod
def _model_generate(self, context, max_length, eos_token_id): pass
@abstractmethod
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
pass
# subclass must implement properties vocab_size, eot_token_id, max_gen_toks.
# TODO: enforce this somehow
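For illustration, a concrete subclass wires these hooks to a real model roughly as follows; this is a minimal, hedged sketch assuming the Hugging Face transformers library, with the class name MiniHFLM and its details purely illustrative (the HFLM added further down in this commit is the real implementation), and with the required properties such as batch_size, vocab_size, eot_token_id, max_gen_toks and max_length omitted for brevity:
import transformers  # assumed available, as in gpt2.py below

class MiniHFLM(BaseLM):
    def __init__(self, device='cuda'):
        super().__init__()
        self.device = device
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        # inps: [batch, sequence] token ids -> [batch, sequence, vocab] logits
        with torch.no_grad():
            return self.gpt2(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(context, max_length=max_length,
                                  eos_token_id=eos_token_id, do_sample=False)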
@@ -162,6 +174,132 @@ class TokenizedLM(LM):
return loglikelihoods
# subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
# TODO: enforce this somehow
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be overestimates rather than underestimates, which is more useful for planning
# - when going through the sorted list, the first element of each batch is always the longest, so it determines that batch's padded context length.
# this simplifies the batching logic and, more importantly, makes automatic adaptive batching much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[1] + x[2]
return (-len(toks), tuple(toks))
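# e.g. for hypothetical requests [("a", [1, 2], [3]), ("b", [4], [5, 6, 7, 8])],
# _collate yields keys (-3, (1, 2, 3)) and (-5, (4, 5, 6, 7, 8)), so "b" sorts first
# and its length sets the padded width for any batch it starts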
# TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
inps = []
contlens = []
inplens = []
padding_length = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.vocab_size] slice
# cont_toks 4 5 6 7 8 9
# when too long to fit in context, truncate from the left
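# e.g. with max_length = 4 and context_enc + continuation_enc = [1, 2, 3, 4, 5, 6] (hypothetical ids),
# [-(4+1):] keeps [2, 3, 4, 5, 6] and [:-1] then drops the final token (it is only ever a
# prediction target, never an input), giving inp = [2, 3, 4, 5]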
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1]
, dtype=torch.long).to(self.device)
inplen, = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen
# pad to length
inp = torch.cat([
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
], dim=0)
inps.append(inp.unsqueeze(0))
contlens.append(cont)
inplens.append(inplen)
multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu() # [batch, seq, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
contlen = len(cont_toks)
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
# cont_toks :: [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
max_equal = (greedy_tokens == cont_toks).all()
#last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
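# gather picks, at each position, the log-probability assigned to the actual continuation token,
# so the sum below is the log-likelihood of the whole continuation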
answer = (float(logits.sum()), bool(max_equal))
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
res.append(answer)
return reord.get_original(res)
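For reference, each request handled here is already tokenized; a hypothetical element and its answer (made-up ids and numbers):
request = (("2+2=", " 4"), [17, 10, 17, 28], [604])  # (cache_key, context ids, continuation ids)
# _loglikelihood_tokens([request]) -> [(-1.4, True)]  # (summed log-probability, is-greedy)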
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are
# multiple tokens or that span multiple tokens correctly
# TODO: extract to TokenizedLM?
res = []
def _collate(x):
toks = self.tok_encode(x[0])
return (len(toks), x[0])
reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str): until = [until]
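# assumes the primary stop sequence encodes to exactly one token; the single-element
# unpacking below raises otherwise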
primary_until, = self.tok_encode(until[0])
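# keep only the most recent max_length - max_gen_toks context tokens, leaving room for generation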
context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
for term in until:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s)
return reord.get_original(res)
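A hypothetical invocation, for orientation (the lm handle, prompt, and output are illustrative only):
requests = [("Question: What color is a clear daytime sky?\nAnswer:", ["\n"])]
lm.greedy_until(requests)  # e.g. [" Blue"]; generation stops at "\n", which the split above strips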
class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems,
@@ -181,17 +319,17 @@ class Task(abc.ABC):
"""Downloads the task dataset if necessary"""
pass
@abc.abstractmethod
@abstractmethod
def has_training_docs(self):
"""Whether the task has a training set"""
pass
@abc.abstractmethod
@abstractmethod
def has_validation_docs(self):
"""Whether the task has a validation set"""
pass
@abc.abstractmethod
@abstractmethod
def has_test_docs(self):
"""Whether the task has a test set"""
pass
@@ -223,15 +361,15 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k)
@abc.abstractmethod
@abstractmethod
def doc_to_text(self, doc):
pass
@abc.abstractmethod
@abstractmethod
def doc_to_target(self, doc):
pass
@abc.abstractmethod
@abstractmethod
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
@@ -245,7 +383,7 @@ class Task(abc.ABC):
"""
pass
@abc.abstractmethod
@abstractmethod
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
@@ -258,7 +396,7 @@ class Task(abc.ABC):
"""
pass
@abc.abstractmethod
@abstractmethod
def aggregation(self):
"""
:returns: {str: [metric_score] -> float}
@@ -267,7 +405,7 @@ class Task(abc.ABC):
"""
pass
@abc.abstractmethod
@abstractmethod
def higher_is_better(self):
"""
:returns: {str: bool}
......
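For orientation on the Task hooks above, a concrete task might fill them in roughly as follows; this is a hedged sketch in which the class name, document fields, and metric choice are illustrative, and the remaining abstract methods (download, has_*_docs, construct_requests, and so on) are omitted for brevity:
class HypotheticalAccuracyTask(Task):
    def doc_to_text(self, doc):
        return "Question: {}\nAnswer:".format(doc["question"])

    def doc_to_target(self, doc):
        return " " + doc["answer"]

    def process_results(self, doc, results):
        # assumes a single loglikelihood result: (log-probability, is-greedy)
        ll, is_greedy = results[0]
        return {"acc": float(is_greedy)}

    def aggregation(self):
        # each submetric name maps to a function from a list of per-document scores to a float
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}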
@@ -2,7 +2,7 @@ import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM, TokenizedLM
from lm_eval.base import LM, BaseLM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
@@ -10,150 +10,7 @@ from abc import ABC, abstractmethod
from typing import Iterable
class TorchLM(TokenizedLM):
@abstractmethod
def _model_generate(self, context, max_length, eos_token_id):
pass
@abstractmethod
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
pass
# subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
# TODO: enforce this somehow
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be overestimates rather than underestimates, which is more useful for planning
# - when going through the sorted list, the first element of each batch is always the longest, so it determines that batch's padded context length.
# this simplifies the batching logic and, more importantly, makes automatic adaptive batching much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[1] + x[2]
return (-len(toks), tuple(toks))
# TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
inps = []
contlens = []
inplens = []
padding_length = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.vocab_size] slice
# cont_toks 4 5 6 7 8 9
# when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1]
, dtype=torch.long).to(self.device)
inplen, = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen
# pad to length
inp = torch.cat([
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
], dim=0)
inps.append(inp.unsqueeze(0))
contlens.append(cont)
inplens.append(inplen)
multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu() # [batch, seq, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
contlen = len(cont_toks)
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
# cont_toks :: [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
max_equal = (greedy_tokens == cont_toks).all()
#last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
answer = (float(logits.sum()), bool(max_equal))
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
res.append(answer)
return reord.get_original(res)
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles untils that are
# multiple tokens or that span multiple tokens correctly
# TODO: extract to TokenizedLM?
res = []
def _collate(x):
toks = self.tok_encode(x[0])
return (len(toks), x[0])
reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered()):
if isinstance(until, str): until = [until]
primary_until, = self.tok_encode(until[0])
context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
for term in until:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), s)
res.append(s)
return reord.get_original(res)
class HFLM(TorchLM):
class HFLM(BaseLM):
def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
super().__init__()
......