Commit 9454c839 authored by Jason Phang

gpt2 perplexity

parent 8846bec0
@@ -27,9 +27,51 @@ class LM(abc.ABC):
:return: list
A list of pairs (logprob, isgreedy)
logprob: float
The log probability of `contination`
The log probability of `continuation`
isgreedy:
Whether `contination` would be generated by greedy sampling from `context`
Whether `continuation` would be generated by greedy sampling from `context`
"""
pass
@abc.abstractmethod
def loglikelihood_perplexity(self, requests):
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT: EOT 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list
A list of strings
string: str
String for which we are computing per-token loglikelihood
:return: list
A list of per-token log probabilities, one entry per input string
logprobs: array of floats
The log probability assigned to each predicted token of the string
"""
pass
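For concreteness, the chunking described above can be reproduced with the get_rolling_token_windows helper added to lm_eval/utils.py in this commit; a minimal sketch (the token IDs and the GPT-2 <|endoftext|> id 50256 are only illustrative):

from lm_eval.utils import get_rolling_token_windows

# Reproduce the docstring example: 10 tokens, max context length 4.
windows = list(get_rolling_token_windows(
    token_list=list(range(10)),
    prefix_token=50256,  # dummy <|endoftext|> prefix
    max_seq_len=4,
    context_len=1,
))
# windows == [
#     ([50256, 0, 1, 2], [0, 1, 2, 3]),
#     ([3, 4, 5, 6],     [4, 5, 6, 7]),
#     ([5, 6, 7, 8],     [8, 9]),
# ]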
@@ -247,9 +289,60 @@ class MultipleChoiceTask(Task):
}
class PerplexityTask(Task, abc.ABC):
def has_training_docs(self):
return False
def fewshot_description(self):
return ""
def fewshot_examples(self, k, rnd):
assert k == 0
return []
def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
assert num_fewshot == 0
assert not provide_description
return ""
def higher_is_better(self):
return {"perplexity": False}
def doc_to_text(self, doc):
return doc
def doc_to_target(self, doc):
raise NotImplementedError()
def construct_requests(self, doc, ctx):
assert not ctx
req = rf.loglikelihood_perplexity(doc)
return req
def process_results(self, doc, results):
loglikelihood, = results
return {
"perplexity": loglikelihood,
}
def aggregation(self):
return {
"perplexity": self.compute_perplexity_from_loglikelihood,
}
@classmethod
def compute_perplexity_from_loglikelihood(cls, loglikelihoods):
aggregate_logprobs = np.concatenate(loglikelihoods)
perplexity = np.exp(-aggregate_logprobs.mean())
return float(perplexity)
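As a quick sanity check of the aggregation with made-up per-token logprobs (perplexity is the exponential of the negative mean logprob over all tokens of all documents):

import numpy as np

doc1_logprobs = np.array([-1.0, -2.0])  # hypothetical per-token logprobs, doc 1
doc2_logprobs = np.array([-3.0])        # hypothetical per-token logprobs, doc 2
ppl = float(np.exp(-np.concatenate([doc1_logprobs, doc2_logprobs]).mean()))
# mean logprob is -2.0, so ppl == exp(2.0) ≈ 7.39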
req_ret_lens = {
'loglikelihood': 2,
'greedy_until': None,
'loglikelihood_perplexity': 1,
}
import os
@@ -34,7 +34,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
task_docs = list(task_doc_func())
rnd = random.Random()
rnd.seed(42)
rnd.shuffle(task_docs)
# rnd.shuffle(task_docs)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
docs[(task_name, doc_id)] = doc
@@ -4,10 +4,13 @@ import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
class GPT2LM(LM):
MAX_GEN_TOKS = 256
VOCAB_SIZE = 50257
EOT_TOKEN_ID = 50256
def __init__(self, device=None, pretrained='gpt2'):
super().__init__()
@@ -39,7 +42,7 @@ class GPT2LM(LM):
for context, continuation in requests:
if context == "":
# end of text as context
context_enc = [50256]
context_enc = [self.EOT_TOKEN_ID]
else:
context_enc = self.tokenizer.encode(context)
@@ -49,6 +52,35 @@ class GPT2LM(LM):
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_perplexity(self, requests):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
with torch.no_grad():
for string, in tqdm(requests):
encoded = self.tokenizer.encode_plus(string)["input_ids"]
rolling_token_windows = utils.get_rolling_token_windows(
token_list=encoded,
prefix_token=self.EOT_TOKEN_ID,
max_seq_len=self.max_length,
context_len=1,
)
string_nll = []
for input_tokens, pred_tokens in rolling_token_windows:
inp = torch.tensor([input_tokens], dtype=torch.long).to(self.device)
labels = torch.tensor([pred_tokens], dtype=torch.long).to(self.device)
logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1) # [batch, seq, vocab]
# Only score the relevant logits (i.e. the last len(pred_tokens) logits)
scoring_logits = logits[:, -len(pred_tokens):].reshape(len(pred_tokens), self.VOCAB_SIZE)
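# Note: F.cross_entropy applies log_softmax internally; since log_softmax is
# idempotent, passing already-normalized log-probabilities here still yields
# the correct per-token negative log-likelihood.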
string_nll.append(
F.cross_entropy(scoring_logits, target=labels.view(-1), reduction="none").cpu().numpy()
)
string_nll = np.concatenate(string_nll)
loglikelihoods.append(-string_nll)
return loglikelihoods
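For reference, a minimal way to exercise this method (assuming the default constructor resolves a device and downloads the gpt2 weights; the input string is arbitrary):

lm = GPT2LM(pretrained='gpt2')
# Requests are 1-tuples, matching the `for string, in requests` unpacking above.
(logprobs,) = lm.loglikelihood_perplexity([("The quick brown fox jumps over the lazy dog.",)])
ppl = float(np.exp(-logprobs.mean()))  # np is already imported in this module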
def _loglikelihood_tokens(self, requests):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
@@ -59,7 +91,7 @@ class GPT2LM(LM):
def _collate(x):
toks = x[1] + x[2]
return (len(toks), tuple(toks))
reord = utils.Reorderer(requests, _collate)
for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
# when too long to fit in context, truncate from the left
@@ -67,7 +99,7 @@ class GPT2LM(LM):
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
cont_toks = inp[:, ctxlen:] # [batch, seq]
logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all()
@@ -92,9 +92,9 @@ class GPT3LM(LM):
# we care about and so we need some kind of backup for when it isn't
toks = x[1] + x[2]
return (len(toks), tuple(toks))
reord = utils.Reorderer(requests, _collate)
for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = []
ctxlens = []
@@ -121,7 +121,7 @@ class GPT3LM(LM):
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return reord.get_original(res)
def greedy_until(self, requests):
@@ -37,6 +37,7 @@ from . import hendrycks_test
from . import hendrycks_math
from . import cbt
from . import lambada_cloze
from . import pile
########################################
# Translation tasks
@@ -171,6 +172,10 @@ TASK_REGISTRY = {
"cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords,
# Pile
"pile_enron": pile.PileEnronPerplexityTask,
"pile_ubuntu": pile.PileUbuntuPerplexityTask,
}
import os
import lm_dataformat
import abc
import numpy as np
from lm_eval.base import rf, PerplexityTask
from ..metrics import mean, matthews_corrcoef, f1_score
from ..utils import general_detokenize
from best_download import download_file
class PilePerplexityTask(PerplexityTask, abc.ABC):
PILE_SET_NAME = None
VAL_PATH = 'data/pile/val.jsonl.zst'
TEST_PATH = 'data/pile/test.jsonl.zst'
def download(self):
os.makedirs("data/pile/", exist_ok=True)
if not os.path.exists(self.VAL_PATH):
download_file("https://the-eye.eu/public/AI/pile/val.jsonl.zst", self.VAL_PATH)
if not os.path.exists(self.TEST_PATH):
download_file("https://the-eye.eu/public/AI/pile/test.jsonl.zst", self.TEST_PATH)
def validation_docs(self):
rdr = lm_dataformat.Reader(self.VAL_PATH)
for doc, metadata in rdr.stream_data(get_meta=True):
if metadata["pile_set_name"] == self.PILE_SET_NAME:
yield doc
def test_docs(self):
rdr = lm_dataformat.Reader(self.TEST_PATH)
for doc, metadata in rdr.stream_data(get_meta=True):
if metadata["pile_set_name"] == self.PILE_SET_NAME:
yield doc
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
class PileEnronPerplexityTask(PilePerplexityTask):
PILE_SET_NAME = "Enron Emails"
class PileUbuntuPerplexityTask(PilePerplexityTask):
PILE_SET_NAME = "Ubuntu IRC"
@@ -61,6 +61,49 @@ def general_detokenize(string):
return string
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
"""
- context_len controls the rolling-window overlap: every prediction window after the first
conditions on at least context_len preceding tokens
:param token_list: list
List of tokens to be PREDICTED
:param prefix_token: token
Dummy token like <eos> so the first token has something to condition on
:param max_seq_len: int
max_seq_len of model (or max_seq_len we want to use)
:param context_len: int
Amount of desired token context for prediction. Needs to be at least 1.
:return: generator
Generator of tuples
(input_tokens, pred_tokens)
Note: Score only the last len(pred_tokens) logits of the LM
"""
assert 1 <= context_len <= max_seq_len
if not token_list:
return
# +1 offset, going from input->preds
pred_len = max_seq_len - context_len + 1
predicted = 0
# Special handling for first window: predict all tokens
first_seq_len = min(max_seq_len, len(token_list))
yield (
[prefix_token] + token_list[:first_seq_len - 1],
token_list[:first_seq_len]
)
predicted += first_seq_len
while predicted < len(token_list):
window_pred_len = min(len(token_list) - predicted, pred_len)
window_end = predicted + window_pred_len
yield (
token_list[window_end - max_seq_len - 1:window_end - 1],
token_list[window_end - window_pred_len:window_end],
)
predicted += window_pred_len
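For illustration, a larger context_len shrinks each window's prediction span (pred_len = max_seq_len - context_len + 1) so that every scored token after the first window keeps at least that much context; a small sketch with made-up token IDs:

windows = list(get_rolling_token_windows(
    token_list=list(range(10)),
    prefix_token=50256,  # dummy prefix token
    max_seq_len=4,
    context_len=2,
))
# windows == [
#     ([50256, 0, 1, 2], [0, 1, 2, 3]),  # first window predicts everything it can
#     ([2, 3, 4, 5],     [4, 5, 6]),     # later windows score at most pred_len == 3 tokens
#     ([5, 6, 7, 8],     [7, 8, 9]),
# ]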
class Reorderer:
def __init__(self, arr, fn):
self.size = len(arr)