Commit 1e7f884d authored by Leo Gao

Refactor PerplexityTask

parent b0cf0163
lm_eval/base.py

 import abc
 import random
 import numpy as np
+import re

-from lm_eval.metrics import mean
+from lm_eval.metrics import mean, perplexity, weighted_mean


 class LM(abc.ABC):
@@ -307,14 +308,17 @@ class PerplexityTask(Task, abc.ABC):
         return ""

     def higher_is_better(self):
-        return False
+        return {
+            "word_perplexity": False,
+            "byte_perplexity": False,
+            "bits_per_byte": False,
+        }

     def doc_to_text(self, doc):
         return doc

     def doc_to_target(self, doc):
-        raise NotImplementedError()
+        return doc

     def construct_requests(self, doc, ctx):
         assert not ctx
@@ -324,20 +328,26 @@ class PerplexityTask(Task, abc.ABC):
     def process_results(self, doc, results):
         loglikelihood, = results
         return {
-            "perplexity": loglikelihood,
+            "word_perplexity": loglikelihood / self.count_words(self.doc_to_text(doc)),
+            "byte_perplexity": loglikelihood / self.count_bytes(self.doc_to_text(doc)),
+            "bits_per_byte": (-loglikelihood, self.count_bytes(self.doc_to_text(doc)))
         }

     def aggregation(self):
         return {
-            "perplexity": self.compute_perplexity_from_loglikelihood,
+            "word_perplexity": perplexity,
+            "byte_perplexity": perplexity,
+            "bits_per_byte": weighted_mean
         }

-    @classmethod
-    def compute_perplexity_from_loglikelihood(cls, loglikelihoods):
-        aggregate_logprobs = np.concatenate(loglikelihoods)
-        perplexity = np.exp(-aggregate_logprobs.mean())
-        return float(perplexity)
+    def count_bytes(self, s):
+        return len(s.encode("utf-8"))
+
+    def count_words(self, s):
+        """ Downstream tasks with custom word boundaries should override this! """
+        return len(re.split(r"\s+", s))
 req_ret_lens = {
     'loglikelihood': 2,
...
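For a concrete sense of the new per-document bookkeeping, here is a small standalone sketch of what process_results now returns for a single document. It is plain Python with no lm_eval imports, and the loglikelihood value is made up for illustration:

import re

def count_bytes(s):
    return len(s.encode("utf-8"))

def count_words(s):
    # Default whitespace word boundary from the diff; downstream tasks can override.
    return len(re.split(r"\s+", s))

doc = "The quick brown fox jumps over the lazy dog"
loglikelihood = -35.2  # made-up rolling loglikelihood for the whole document, in nats

results = {
    "word_perplexity": loglikelihood / count_words(doc),  # average loglikelihood per word
    "byte_perplexity": loglikelihood / count_bytes(doc),  # average loglikelihood per byte
    "bits_per_byte": (-loglikelihood, count_bytes(doc)),  # (numerator, denominator) pair for weighted_mean
}
print(results)

The perplexity aggregator in the metrics diff below then exponentiates the negative mean of those per-word (or per-byte) values across documents.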
lm_eval/metrics.py

@@ -62,6 +62,11 @@ def perplexity(items):
     return math.exp(-mean(items))


+def weighted_mean(items):
+    a, b = zip(*items)
+    return sum(a) / sum(b)
+
+
 def bleu(items):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
...
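In use, the two aggregators look like this. A toy sketch with invented numbers; the local mean, perplexity and weighted_mean definitions simply restate the functions shown in this file so the snippet runs on its own:

import math

def mean(items):
    return sum(items) / len(items)

def perplexity(items):
    # exp of the negative mean of the per-document values
    return math.exp(-mean(items))

def weighted_mean(items):
    # items are (numerator, denominator) pairs; result is sum(num) / sum(den)
    a, b = zip(*items)
    return sum(a) / sum(b)

# per-document average loglikelihood per word (invented values)
print(perplexity([-4.2, -3.9, -4.5]))           # ~= 66.7

# per-document (-total_loglikelihood, byte_count) pairs, as used for bits_per_byte
print(weighted_mean([(35.2, 44), (61.0, 80)]))  # ~= 0.776, nats per byte if the loglikelihoods are natural logs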
lm_eval/models/gpt2.py

@@ -60,6 +60,7 @@ class GPT2LM(LM):
         with torch.no_grad():
             for string, in tqdm(requests):
                 encoded = self.tokenizer.encode_plus(string)["input_ids"]
+
                 rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                     token_list=encoded,
                     prefix_token=self.EOT_TOKEN_ID,
@@ -67,9 +68,9 @@ class GPT2LM(LM):
                     context_len=1,
                 )))

-                # todo: figure out partial caching
                 rolling_token_windows = [(None,) + x for x in rolling_token_windows]

+                # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
                 string_nll = self._loglikelihood_tokens(rolling_token_windows)

                 # discard is_greedy
...
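For readers unfamiliar with the rolling-window scoring used here, the goal is to score every token of a document exactly once without exceeding the model's context length. The sketch below only illustrates that idea; it is not the logic of utils.get_rolling_token_windows or utils.make_disjoint_window, and the token ids and sizes are toy values:

def rolling_windows(tokens, prefix_token, max_seq_len, context_len=1):
    """Yield (context, continuation) pairs that cover `tokens` exactly once.

    Each window keeps context_len tokens of left context (the prefix token at the
    very start) and fills the rest of max_seq_len with new tokens to score, so a
    document longer than the model context is split into disjoint chunks.
    Simplified illustration of the rolling-window idea, not the library code.
    """
    windows = []
    start = 0
    step = max_seq_len - context_len  # new tokens scored per window
    while start < len(tokens):
        end = min(start + step, len(tokens))
        context = [prefix_token] if start == 0 else tokens[start - context_len:start]
        windows.append((context, tokens[start:end]))
        start = end
    return windows

# toy example: 10 "tokens", model context 4, EOT id 50256
print(rolling_windows(list(range(10)), prefix_token=50256, max_seq_len=4))
# [([50256], [0, 1, 2]), ([2], [3, 4, 5]), ([5], [6, 7, 8]), ([8], [9])]

In the diff above, the leading None prepended to each window appears to fill the request-key slot expected by _loglikelihood_tokens; the per-window scores then make up string_nll, with the is_greedy flags discarded as the comment notes.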