Commit b57d059a authored by Leo Gao's avatar Leo Gao
Browse files

Fix lambada

parent 8eada53b
......@@ -2,6 +2,7 @@ import abc
import random
import numpy as np
import sklearn
import math
class LM(abc.ABC):
......@@ -229,6 +230,9 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
def perplexity(items):
return math.exp(-mean(items))
req_ret_lens = {
'loglikelihood': 2
}
......
from lm_eval.base import Task, rf, mean
from lm_eval.base import Task, rf, mean, perplexity
from lm_eval.utils import sh
import json
import math
......@@ -45,21 +45,23 @@ class LAMBADA(Task):
return ""
def construct_requests(self, doc, ctx):
ll, is_greedy = rf.loglikelihood(doc, self.doc_to_target(doc))
ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
return ll, is_greedy
def process_results(self, doc, results):
ll, is_greedy = results
print(ll)
return {
'perplexity': math.exp(-ll),
'perplexity': ll,
'accuracy': int(is_greedy)
}
def aggregation(self):
return {
'perplexity': mean,
'perplexity': perplexity,
'accuracy': mean
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment