Commit 1e7f884d authored by Leo Gao

Refactor PerplexityTask

parent b0cf0163
 import abc
 import random
 import numpy as np
+import re
-from lm_eval.metrics import mean
+from lm_eval.metrics import mean, perplexity, weighted_mean

 class LM(abc.ABC):
@@ -307,14 +308,17 @@ class PerplexityTask(Task, abc.ABC):
         return ""

     def higher_is_better(self):
-        return False
+        return {
+            "word_perplexity": False,
+            "byte_perplexity": False,
+            "bits_per_byte": False,
+        }

     def doc_to_text(self, doc):
         return doc

     def doc_to_target(self, doc):
-        raise NotImplementedError()
+        return doc

     def construct_requests(self, doc, ctx):
         assert not ctx
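For orientation, a downstream perplexity task now only has to supply documents; the text/target mapping and metrics above come from PerplexityTask. A minimal, hypothetical subclass sketch (the class name, file path, and loader are invented for illustration, and the remaining Task hooks are omitted):

```python
from lm_eval.base import PerplexityTask


class LocalCorpusPerplexity(PerplexityTask):
    """Hypothetical task: perplexity over newline-separated documents in a local file."""

    def has_validation_docs(self):
        return True

    def validation_docs(self):
        # Each doc is a plain string; PerplexityTask.doc_to_target returns it
        # unchanged, and process_results normalizes by its word/byte counts.
        with open("corpus.txt") as f:  # assumed path, illustration only
            for line in f:
                yield line.rstrip("\n")
```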
@@ -324,20 +328,26 @@ class PerplexityTask(Task, abc.ABC):
     def process_results(self, doc, results):
         loglikelihood, = results
         return {
-            "perplexity": loglikelihood,
+            "word_perplexity": loglikelihood / self.count_words(self.doc_to_text(doc)),
+            "byte_perplexity": loglikelihood / self.count_bytes(self.doc_to_text(doc)),
+            "bits_per_byte": (-loglikelihood, self.count_bytes(self.doc_to_text(doc)))
         }

     def aggregation(self):
         return {
-            "perplexity": self.compute_perplexity_from_loglikelihood,
+            "word_perplexity": perplexity,
+            "byte_perplexity": perplexity,
+            "bits_per_byte": weighted_mean
         }

-    @classmethod
-    def compute_perplexity_from_loglikelihood(cls, loglikelihoods):
-        aggregate_logprobs = np.concatenate(loglikelihoods)
-        perplexity = np.exp(-aggregate_logprobs.mean())
-        return float(perplexity)
+    def count_bytes(self, s):
+        return len(s.encode("utf-8"))
+
+    def count_words(self, s):
+        """ Downstream tasks with custom word boundaries should override this! """
+        return len(re.split(r"\s+", s))

 req_ret_lens = {
     'loglikelihood': 2,
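To see what the new process_results / aggregation pair computes, here is a standalone sketch of the same arithmetic outside the harness (plain functions mirroring the methods above; the document and loglikelihood value are invented for illustration):

```python
import math
import re

def count_bytes(s):
    return len(s.encode("utf-8"))

def count_words(s):
    return len(re.split(r"\s+", s))

doc = "the quick brown fox jumps over the lazy dog"
loglikelihood = -40.0  # pretend total loglikelihood (in nats) the model assigns to doc

# Per-document values, as returned by process_results:
word_item = loglikelihood / count_words(doc)    # average logprob per word
byte_item = loglikelihood / count_bytes(doc)    # average logprob per byte
bpb_item = (-loglikelihood, count_bytes(doc))   # (total negative logprob, byte count)

# Corpus-level aggregation, as in lm_eval.metrics (shown here for a one-document corpus):
word_perplexity = math.exp(-word_item)          # perplexity = exp(-mean(items))
bits_per_byte = bpb_item[0] / bpb_item[1]       # weighted_mean = sum(a) / sum(b)
print(word_perplexity, byte_item, bits_per_byte)
```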
@@ -62,6 +62,11 @@ def perplexity(items):
     return math.exp(-mean(items))

+def weighted_mean(items):
+    a, b = zip(*items)
+    return sum(a) / sum(b)

 def bleu(items):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
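The point of weighted_mean is that each pair's numerator and weight are summed before dividing, so longer documents contribute proportionally more than a plain mean of per-document ratios would. A quick check of that behaviour:

```python
def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)

# Two documents: 12.0 total nats over 6 bytes, and 2.0 total nats over 4 bytes.
items = [(12.0, 6), (2.0, 4)]
assert weighted_mean(items) == 14.0 / 10   # 1.4 overall
assert (12.0 / 6 + 2.0 / 4) / 2 == 1.25    # naive per-document mean differs
```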
@@ -60,6 +60,7 @@ class GPT2LM(LM):
         with torch.no_grad():
             for string, in tqdm(requests):
                 encoded = self.tokenizer.encode_plus(string)["input_ids"]

                 rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                     token_list=encoded,
                     prefix_token=self.EOT_TOKEN_ID,
@@ -67,9 +68,9 @@ class GPT2LM(LM):
                     context_len=1,
                 )))

-                # todo: figure out partial caching
                 rolling_token_windows = [(None,) + x for x in rolling_token_windows]

+                # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
                 string_nll = self._loglikelihood_tokens(rolling_token_windows)

                 # discard is_greedy
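The rolling-window helpers called above chop a tokenized document into (context, continuation) chunks that fit the model's window, with every token scored exactly once; the harness then sums the per-window loglikelihoods into the per-document total that process_results consumes. The real logic lives in lm_eval.utils (get_rolling_token_windows plus make_disjoint_window); the following is only a simplified, self-contained sketch of the idea, with invented parameter values:

```python
def rolling_windows(tokens, prefix_token, max_seq_len, context_len=1):
    """Yield (context, continuation) pairs that cover `tokens` exactly once.

    The first window conditions only on `prefix_token`; later windows keep
    `context_len` tokens of left context so every scored token has something
    to condition on, while the continuations stay disjoint.
    """
    assert max_seq_len > context_len
    windows = []
    pos = 0
    while pos < len(tokens):
        if pos == 0:
            context = [prefix_token]
            continuation = tokens[:max_seq_len]
        else:
            context = tokens[pos - context_len:pos]
            continuation = tokens[pos:pos + max_seq_len - context_len]
        windows.append((context, continuation))
        pos += len(continuation)
    return windows


# Toy example: a 7-token document and a model limited to 4 positions.
print(rolling_windows(list(range(7)), prefix_token=-1, max_seq_len=4))
# [([-1], [0, 1, 2, 3]), ([3], [4, 5, 6])]
```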