Commit fea936e6 authored by Leo Gao

Merge TokenizedLM and TorchLM into BaseLM

parent 7f24a08b
@@ -4,16 +4,17 @@ from typing import Iterable
import numpy as np
import re
from tqdm import tqdm
import torch
import torch.nn.functional as F
from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
from lm_eval import utils
from abc import abstractmethod

class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.

        Downstream tasks should attempt to use loglikelihood instead of other
@@ -37,7 +38,7 @@ class LM(abc.ABC):
""" """
pass pass
@abc.abstractmethod @abstractmethod
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests):
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model. - We will use the full max context length of the model.
@@ -80,7 +81,7 @@ class LM(abc.ABC):
        pass

    # TODO: Add an optional max length
    @abstractmethod
    def greedy_until(self, requests):
        """Generate greedily until a stopping sequence

@@ -108,15 +109,26 @@ class LM(abc.ABC):
        self.cache_hook = cache_hook

class BaseLM(LM):
    @abstractmethod
    def tok_encode(self, string: str): pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]): pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id): pass

    @abstractmethod
    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        pass

    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks.
    # TODO: enforce this somehow

@@ -162,6 +174,132 @@ class TokenizedLM(LM):
        return loglikelihoods

    # subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
    # TODO: enforce this somehow
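
To make the new contract concrete, here is a minimal, purely illustrative sketch of a BaseLM subclass: the tokenizer hooks, the two model hooks, and the properties named in the comments above, wired to a Hugging Face GPT-2. The real wrapper is HFLM further down; the class name and the fixed max_gen_toks value here are made up.

```python
# Illustrative sketch of a BaseLM subclass; not taken from this diff.
import torch
import transformers
from lm_eval.base import BaseLM


class DummyGPT2LM(BaseLM):
    def __init__(self, device="cuda", batch_size=1):
        super().__init__()
        self._device = device
        self._batch_size = batch_size
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
        self.model = transformers.GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()

    # properties the BaseLM machinery reads
    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self.model.config.n_ctx

    @property
    def max_gen_toks(self):
        return 256  # arbitrary illustrative cap

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return self._device

    # tokenizer hooks
    def tok_encode(self, string):
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    # model hooks used by _loglikelihood_tokens / greedy_until below
    def _model_call(self, inps):
        # inps: [batch, sequence] token ids -> [batch, sequence, vocab] logits
        with torch.no_grad():
            return self.model(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        return self.model.generate(
            context,
            max_length=max_length,
            eos_token_id=eos_token_id,
            do_sample=False,
        )
```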
    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over- rather than under-estimates, which is more useful for planning
            # - when going through the sorted list, the first element of each batch is the longest, so its length is
            #   the batch's padded context length. this simplifies the batching logic and, more importantly, makes
            #   automatic adaptive batches much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = x[1] + x[2]
            return (-len(toks), tuple(toks))

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps = []
            contlens = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying
            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                #                                 [:, -len(continuation_enc):, :self.vocab_size] slice
                # cont_toks      4 5 6 7 8 9
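                # put differently: because inp drops the last token, the logits at position i are the
                # model's prediction for token i+1, so logits[inplen - contlen : inplen] (taken below)
                # are exactly the distributions that get scored against the continuation tokens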
                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
                    dtype=torch.long
                ).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad to length
                inp = torch.cat([
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))
                contlens.append(cont)
                inplens.append(inplen)

            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
                contlen = len(cont_toks)

                logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)

                # cont_toks :: [1, seq]
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)

                max_equal = (greedy_tokens == cont_toks).all()

                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
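
For orientation, a hypothetical example of the shapes this method consumes and produces; the strings and token ids below are invented, not taken from the diff:

```python
# Hypothetical request: (cache_key, context_enc, continuation_enc), already tokenized.
requests = [
    (("The cat sat on the", " mat"),  # cache_key (here a pair of strings, made up)
     [101, 102, 103, 104, 105],       # context token ids (made up)
     [106]),                          # continuation token ids (made up)
]
# lm._loglikelihood_tokens(requests) returns, in the original request order, one
# (float, bool) pair per request: the summed log-prob of the continuation and
# whether greedy decoding would have reproduced it exactly.
```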
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly
        # TODO: extract to TokenizedLM?
        res = []

        def _collate(x):
            toks = self.tok_encode(x[0])
            return (len(toks), x[0])

        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str): until = [until]

            primary_until, = self.tok_encode(until[0])

            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)

            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)
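
A similarly hypothetical sketch for greedy_until; note that the implementation above assumes the first stop sequence encodes to a single token (primary_until). The strings are invented:

```python
requests = [
    ("Question: What is the capital of France?\nAnswer:", ["\n"]),  # (context, until)
]
# lm.greedy_until(requests) -> one string per request: the greedy continuation,
# decoded and cut at the first occurrence of any stop sequence in `until`.
```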

class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,

@@ -181,17 +319,17 @@ class Task(abc.ABC):
        """Downloads the task dataset if necessary"""
        pass

    @abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

@@ -223,15 +361,15 @@ class Task(abc.ABC):
        return rnd.sample(self._training_docs, k)

    @abstractmethod
    def doc_to_text(self, doc):
        pass

    @abstractmethod
    def doc_to_target(self, doc):
        pass

    @abstractmethod
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

@@ -245,7 +383,7 @@ class Task(abc.ABC):
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of

@@ -258,7 +396,7 @@ class Task(abc.ABC):
        """
        pass

    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}

@@ -267,7 +405,7 @@ class Task(abc.ABC):
        """
        pass

    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}

@@ -2,7 +2,7 @@ import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM, BaseLM
from lm_eval import utils
from tqdm import tqdm
import numpy as np

@@ -10,150 +10,7 @@ from abc import ABC, abstractmethod
from typing import Iterable

[Removed from this file: the TorchLM class, whose abstract _model_generate/_model_call hooks and _loglikelihood_tokens/greedy_until implementations are identical to the code moved into BaseLM above, and the old class HFLM(TorchLM) declaration. HFLM now subclasses BaseLM directly.]

class HFLM(BaseLM):
    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
        super().__init__()
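
For completeness, a usage sketch with the constructor signature shown above; the import path is an assumption, since the diff does not show file names:

```python
from lm_eval.models.gpt2 import HFLM  # assumed module path

lm = HFLM(pretrained="gpt2", device="cuda", batch_size=1)
# HFLM supplies the tokenizer/model hooks; BaseLM layers loglikelihood,
# loglikelihood_rolling, and greedy_until on top of them.
```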