Commit b0cf0163 authored by Leo Gao

Begin refactoring perplexity code

parent ee5467ff
@@ -26,3 +26,11 @@ class DummyLM(LM):
             assert ctx.strip() != ''
 
         return res
+
+    def loglikelihood_perplexity(self, requests):
+        res = []
+
+        for _ in requests:
+            res.append(-random.random())
+
+        return res
\ No newline at end of file
@@ -60,24 +60,23 @@ class GPT2LM(LM):
         with torch.no_grad():
             for string, in tqdm(requests):
                 encoded = self.tokenizer.encode_plus(string)["input_ids"]
-                rolling_token_windows = utils.get_rolling_token_windows(
+                rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                     token_list=encoded,
                     prefix_token=self.EOT_TOKEN_ID,
                     max_seq_len=self.max_length,
                     context_len=1,
-                )
-                string_nll = []
-                for input_tokens, pred_tokens in rolling_token_windows:
-                    inp = torch.tensor([input_tokens], dtype=torch.long).to(self.device)
-                    labels = torch.tensor([pred_tokens], dtype=torch.long).to(self.device)
-                    logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1)  # [batch, seq, vocab]
-                    # Only score the relevant logits (i.e. the last len(pred_tokens) logits)
-                    scoring_logits = logits[:, -len(pred_tokens):].reshape(len(pred_tokens), self.VOCAB_SIZE)
-                    string_nll.append(
-                        F.cross_entropy(scoring_logits, target=labels.view(-1), reduction="none").cpu().numpy()
-                    )
-                string_nll = np.concatenate(string_nll)
-                loglikelihoods.append(-string_nll)
+                )))
+
+                # todo: figure out partial caching
+                rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+                string_nll = self._loglikelihood_tokens(rolling_token_windows)
+
+                # discard is_greedy
+                string_nll = [x[0] for x in string_nll]
+
+                string_nll = sum(string_nll)
+                loglikelihoods.append(string_nll)
 
         return loglikelihoods
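The refactor above routes perplexity scoring through the same _loglikelihood_tokens path used for ordinary loglikelihood requests. A minimal sketch of the new pipeline, using a hypothetical stand-in scorer in place of GPT2LM._loglikelihood_tokens (everything except the two utils helpers is illustration only):

    from lm_eval import utils

    def fake_loglikelihood_tokens(requests):
        # hypothetical stand-in: returns one (logprob, is_greedy) pair per
        # (cache_key, context, continuation) request, like _loglikelihood_tokens
        return [(-1.0 * len(cont), False) for _, _, cont in requests]

    encoded = list(range(10))  # pretend token ids for one string
    windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
        token_list=encoded, prefix_token=0, max_seq_len=4, context_len=1,
    )))
    requests = [(None,) + x for x in windows]  # None cache_key: no partial caching yet
    string_nll = sum(x[0] for x in fake_loglikelihood_tokens(requests))  # total loglikelihood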
@@ -94,12 +93,25 @@ class GPT2LM(LM):
         reord = utils.Reorderer(requests, _collate)
         for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
+            assert len(context_enc) > 0
+            assert len(continuation_enc) > 0
+            assert len(continuation_enc) <= self.max_length
+
+            # how this all works:
+            #          CTX      CONT
+            # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+            # gpt2    \               \
+            # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
+            # cont_toks      4 5 6 7 8 9
+
             # when too long to fit in context, truncate from the left
-            inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
-            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
-            cont_toks = inp[:, ctxlen:]  # [batch, seq]
-            logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+            inp = torch.tensor([
+                (context_enc + continuation_enc)[-(self.max_length+1):]
+            ], dtype=torch.long).to(self.device)
+
+            cont_toks = inp[:, -len(continuation_enc):]  # [batch, seq]
+
+            logits = F.log_softmax(self.gpt2(inp[:, :-1])[0][:, -len(continuation_enc):, :self.VOCAB_SIZE], dim=-1)  # [batch, seq, vocab] - vocab size is clipped to exclude padding tokens
+
             greedy_tokens = logits.argmax(dim=-1)
             max_equal = (greedy_tokens == cont_toks).all()
@@ -108,7 +120,7 @@ class GPT2LM(LM):
             logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
 
-            answer = (float(logits.sum()), bool(max_equal))
+            answer = (float(logits.cpu().to(torch.float64).sum()), bool(max_equal))
 
             # partial caching
             if cache_key is not None:
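The alignment in the comment diagram above can be sanity-checked without a model: if token ids are chosen to equal their positions in the sequence, the logit slot computed at position i of inp[:, :-1] must line up with token i+1. A toy check along those lines (assumes only torch; no model involved):

    import torch

    ctx, cont = [0, 1, 2, 3], [4, 5, 6, 7, 8, 9]
    inp = torch.tensor([ctx + cont])
    model_in = inp[:, :-1]  # last token is deleted: there is no token after it to predict
    positions = torch.arange(model_in.shape[1]).unsqueeze(0)
    scoring_positions = positions[:, -len(cont):]  # the slots the logits slice keeps
    cont_toks = inp[:, -len(cont):]
    assert (scoring_positions + 1 == cont_toks).all()  # slot i scores token i + 1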
@@ -97,12 +97,19 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
     while predicted < len(token_list):
         window_pred_len = min(len(token_list) - predicted, pred_len)
         window_end = predicted + window_pred_len
 
         yield (
             token_list[window_end - max_seq_len - 1:window_end - 1],
             token_list[window_end - window_pred_len:window_end],
         )
         predicted += window_pred_len
 
 
+def make_disjoint_window(pair):
+    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
+    a, b = pair
+    # len(a) - (len(b) - 1) keeps the full context when len(b) == 1; a[:-(len(b) - 1)] would give a[:-0] == [] there
+    return a[:len(a) - (len(b) - 1)], b
+
+
 class Reorderer:
     def __init__(self, arr, fn):
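For concreteness, here is what the two helpers produce on a short sequence. The expected values assume the first-window seeding in the unshown head of get_rolling_token_windows, which prepends prefix_token and predicts every token of the first window:

    from lm_eval.utils import get_rolling_token_windows, make_disjoint_window

    windows = list(get_rolling_token_windows(
        token_list=[0, 1, 2, 3, 4, 5], prefix_token=-1, max_seq_len=4, context_len=1,
    ))
    assert windows == [([-1, 0, 1, 2], [0, 1, 2, 3]), ([1, 2, 3, 4], [4, 5])]

    # each context overlaps its continuation by all but one token;
    # make_disjoint_window trims the overlap off the context
    assert [make_disjoint_window(w) for w in windows] == [([-1], [0, 1, 2, 3]), ([1, 2, 3], [4, 5])]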
@@ -15,6 +15,7 @@ def test_evaluator(taskname, Task):
     def ll_fn(reqs):
         for ctx, cont in reqs:
+            if len(ctx) == 0: continue
             # space convention
             assert ctx[-1] != ' '
             assert cont[0] == ' ' or ctx[-1] == '\n'
@@ -41,13 +41,11 @@ def test_gpt2_perplexity():
     gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
     test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
     perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
-    targets = [-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487]
-    for pred, tgt in zip(perplexity, targets):
-        assert pred == pytest.approx(tgt)
+    tgt = sum([-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487])
+    assert perplexity == pytest.approx(tgt, abs=1e-3)
 
     # Hack: modify gpt2 to have shorter context length to induce rolling windows
     gpt2.max_length = 5
     perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
-    targets = [-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813]
-    for pred, tgt in zip(perplexity, targets):
-        assert pred == pytest.approx(tgt)
+    tgt = sum([-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813])
+    assert perplexity == pytest.approx(tgt, abs=1e-3)
-from lm_eval.utils import get_rolling_token_windows
+from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
 
 
 # noinspection DuplicatedCode
@@ -200,3 +200,8 @@ def test_get_rolling_token_windows_empty():
     for _ in generator:
         n += 1
     assert n == 0
+
+
+def test_make_disjoint_window():
+    assert make_disjoint_window(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == ([1], [2, 3, 4, 5, 6])
+    assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
\ No newline at end of file
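One edge case worth covering as well: a single-token continuation, where there is no overlap to trim. With the slice written as a[:len(a) - (len(b) - 1)] the context survives intact, so a suggested extra assertion would be:

    assert make_disjoint_window(([1, 2, 3], [4])) == ([1, 2, 3], [4])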