".github/vscode:/vscode.git/clone" did not exist on "19e5a890f70b95a55c9de6a55357d78fc0a4ff81"
Commit 1e7f884d authored by Leo Gao

Refactor PerplexityTask

parent b0cf0163
lm_eval/base.py:

 import abc
 import random
 import numpy as np
+import re
-from lm_eval.metrics import mean
+from lm_eval.metrics import mean, perplexity, weighted_mean


 class LM(abc.ABC):
 ...

@@ -307,14 +308,17 @@ class PerplexityTask(Task, abc.ABC):
return ""
def higher_is_better(self):
return False
return {
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def doc_to_text(self, doc):
return doc
def doc_to_target(self, doc):
raise NotImplementedError()
return doc
def construct_requests(self, doc, ctx):
assert not ctx
......@@ -324,20 +328,26 @@ class PerplexityTask(Task, abc.ABC):
     def process_results(self, doc, results):
         loglikelihood, = results
         return {
-            "perplexity": loglikelihood,
+            "word_perplexity": loglikelihood / self.count_words(self.doc_to_text(doc)),
+            "byte_perplexity": loglikelihood / self.count_bytes(self.doc_to_text(doc)),
+            "bits_per_byte": (-loglikelihood, self.count_bytes(self.doc_to_text(doc)))
         }

     def aggregation(self):
         return {
-            "perplexity": self.compute_perplexity_from_loglikelihood,
+            "word_perplexity": perplexity,
+            "byte_perplexity": perplexity,
+            "bits_per_byte": weighted_mean
         }

-    @classmethod
-    def compute_perplexity_from_loglikelihood(cls, loglikelihoods):
-        aggregate_logprobs = np.concatenate(loglikelihoods)
-        perplexity = np.exp(-aggregate_logprobs.mean())
-        return float(perplexity)
+    def count_bytes(self, s):
+        return len(s.encode("utf-8"))
+
+    def count_words(self, s):
+        """ Downstream tasks with custom word boundaries should override this! """
+        return len(re.split(r"\s+", s))
 req_ret_lens = {
     'loglikelihood': 2,
...
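To see how the refactored pieces fit together: process_results now emits one entry per metric per document, and aggregation pairs each metric name with a corpus-level reducer from lm_eval.metrics. Below is a minimal, self-contained sketch of that pipeline; the two documents and their log-likelihoods are hypothetical, and mean/perplexity/weighted_mean are inlined copies of the metrics used here.

import math
import re

def mean(items):
    return sum(items) / len(items)

def perplexity(items):
    # exponentiated negative mean log-likelihood, as in lm_eval.metrics
    return math.exp(-mean(items))

def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)

docs = ["The quick brown fox jumps over the lazy dog.", "Hello world"]
loglikelihoods = [-23.4, -11.2]  # hypothetical per-document totals, in nats

word_items, byte_items, bpb_items = [], [], []
for doc, ll in zip(docs, loglikelihoods):
    n_words = len(re.split(r"\s+", doc))  # PerplexityTask.count_words default
    n_bytes = len(doc.encode("utf-8"))    # PerplexityTask.count_bytes
    word_items.append(ll / n_words)       # "word_perplexity" entries
    byte_items.append(ll / n_bytes)       # "byte_perplexity" entries
    bpb_items.append((-ll, n_bytes))      # "bits_per_byte" entries

print(perplexity(word_items))    # corpus word-level perplexity
print(perplexity(byte_items))    # corpus byte-level perplexity
print(weighted_mean(bpb_items))  # total NLL over total bytes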
lm_eval/metrics.py:

@@ -62,6 +62,11 @@ def perplexity(items):
     return math.exp(-mean(items))

+
+def weighted_mean(items):
+    a, b = zip(*items)
+    return sum(a) / sum(b)
+
 def bleu(items):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
...
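weighted_mean is built for metrics whose per-document result is a (numerator, denominator) pair: it sums each side before dividing, so longer documents carry proportionally more weight than a plain mean of per-document ratios would give them. A tiny illustration with made-up numbers:

items = [(8.5, 21), (4.9, 11)]  # hypothetical (-loglikelihood, byte count) pairs
a, b = zip(*items)
print(sum(a) / sum(b))          # 13.4 / 32 = 0.41875

Note that with process_results supplying (-loglikelihood, byte_count) tuples, this aggregate is negative log-likelihood in nats per byte; converting it to literal bits per byte would require a further division by math.log(2).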
lm_eval/models/gpt2.py:

@@ -60,6 +60,7 @@ class GPT2LM(LM):
         with torch.no_grad():
             for string, in tqdm(requests):
                 encoded = self.tokenizer.encode_plus(string)["input_ids"]
                 rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                     token_list=encoded,
                     prefix_token=self.EOT_TOKEN_ID,
@@ -67,9 +68,9 @@ class GPT2LM(LM):
                     context_len=1,
                 )))
-                # todo: figure out partial caching
                 rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+                # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
                 string_nll = self._loglikelihood_tokens(rolling_token_windows)
                 # discard is_greedy
...
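The rolling-window calls exist because a document can be longer than the model's context window. The sketch below illustrates what utils.make_disjoint_window applied to the output of utils.get_rolling_token_windows yields for context_len=1; it is a simplified illustration of the idea, not the library's exact code. Every token is predicted exactly once: the first chunk is conditioned on the EOT prefix token, and each later chunk keeps a single token of context. The (None,) prepended in the diff then fills a leading slot expected by _loglikelihood_tokens; the TODO comments suggest that slot relates to caching.

def rolling_windows(tokens, prefix_token, max_seq_len):
    # Split `tokens` into disjoint prediction chunks of at most max_seq_len
    # tokens, each paired with a short context: the prefix token for the
    # first chunk, the single preceding token for the rest (context_len=1).
    windows = []
    start = 0
    while start < len(tokens):
        end = min(start + max_seq_len, len(tokens))
        context = [prefix_token] if start == 0 else tokens[start - 1:start]
        windows.append((context, tokens[start:end]))
        start = end
    return windows

# Ten tokens, window of four: predicted chunks [0..3], [4..7], [8..9].
print(rolling_windows(list(range(10)), prefix_token=50256, max_seq_len=4))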