Commit 5aa601f3 authored by Leo Gao

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness

parents 42659c34 f984c88e
@@ -48,7 +48,6 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)): reqs = [reqs]
             for i, req in enumerate(reqs):
                 requests[req.type].append(req)
                 # i: index in requests for a single task instance
@@ -90,4 +89,4 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
         task = task_dict[task_name]
         results[task_name][metric] = task.aggregation()[metric](items)
     return results
\ No newline at end of file
@@ -109,7 +109,7 @@ TASK_REGISTRY = {
     "hellaswag": hellaswag.HellaSwag, # not implemented yet
     "openbookqa": openbookqa.OpenBookQA,
     # "sat": sat.SATAnalogies, # not implemented yet
-    # "squad": squad.SQuAD, # not implemented yet
+    "squad2": squad.SQuAD2,
     "race": race.RACE,
     # "naturalqs": naturalqs.NaturalQs, # not implemented yet
     "headqa": headqa.HeadQA,
...
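For orientation, a registry entry like the new "squad2" line is just a name-to-class mapping; the evaluator receives already-instantiated tasks keyed by name. A minimal sketch of that lookup, assuming `TASK_REGISTRY` lives in `lm_eval.tasks` as this diff suggests (the helper below is illustrative, not the harness API):

```python
# Hypothetical helper: resolve task names against the registry shown above.
from lm_eval import tasks

def build_task_dict(task_names):
    # tasks.TASK_REGISTRY maps a string name to a Task subclass; instantiating
    # the class loads its dataset through the HFTask machinery.
    return {name: tasks.TASK_REGISTRY[name]() for name in task_names}

task_dict = build_task_dict(["squad2"])   # {"squad2": <SQuAD2 instance>}
```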
+import datasets
+from math import exp
+from lm_eval.base import rf
+from lm_eval.metrics import f1_score, mean
 from . common import HFTask
+from functools import partial
+
+
+def _squad_metric(predictions, references):
+    squad_metric = datasets.load_metric("squad_v2")
+    return squad_metric.compute(predictions=predictions, references=references)
+
+
+def _squad_agg(key, items):
+    predictions, references = zip(*items)
+    return _squad_metric(predictions=predictions, references=references)[key]
+
+
-class SQuAD(HFTask):
+class SQuAD2(HFTask):
     DATASET_PATH = "squad_v2"
     DATASET_NAME = None
@@ -15,16 +31,14 @@ class SQuAD(HFTask):
         return False
 
     def training_docs(self):
-        if self.has_training_docs():
-            return self.data["train"]
+        return self.data["train"]
 
     def validation_docs(self):
-        if self.has_validation_docs():
-            return self.data["validation"]
+        return self.data["validation"]
 
     def fewshot_description(self):
-        # TODO: redo description
-        return "Title: The_Title_of_It\n\nBackground: A text passage as background to answer the question with.\n\nQ: Question about the passage.\n\nA: Answer."
+        # TODO: figure out description
+        return ""
 
     def doc_to_text(self, doc):
         return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
@@ -35,7 +49,7 @@ class SQuAD(HFTask):
             answer = answer_list[0]
         else:
             answer = 'unanswerable'
-        return answer
+        return " " + answer
 
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -48,8 +62,9 @@ class SQuAD(HFTask):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        continuation = rf.greedy_until(ctx, ['\n'])
+        is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
+        return continuation, is_unanswerable
 
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -61,8 +76,31 @@ class SQuAD(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        continuation, (logprob_unanswerable, _) = results
+
+        no_answer_probability = exp(logprob_unanswerable)
+
+        predictions = {
+            'id': doc['id'],
+            'prediction_text': continuation,
+            'no_answer_probability': no_answer_probability,
+        }
+
+        references = {
+            'id': doc['id'],
+            'answers': doc['answers'],
+        }
+
+        return {
+            'exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
+            'f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
+            'HasAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
+            'NoAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
+            'NoAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
+            'best_exact': (predictions, references), # Best exact match (with varying threshold)
+            'best_f1': (predictions, references), # Best F1 (with varying threshold)
+        }
 
     def aggregation(self):
         """
@@ -70,8 +108,16 @@ class SQuAD(HFTask):
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            'exact': partial(_squad_agg, 'exact'), # Exact match (the normalized answer exactly match the gold answer)
+            'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': partial(_squad_agg, 'HasAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
+            'HasAns_f1': partial(_squad_agg, 'HasAns_f1'), # The F-score of predicted tokens versus the gold answer
+            'NoAns_exact': partial(_squad_agg, 'NoAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
+            'NoAns_f1': partial(_squad_agg, 'NoAns_f1'), # The F-score of predicted tokens versus the gold answer
+            'best_exact': partial(_squad_agg, 'best_exact'), # Best exact match (with varying threshold)
+            'best_f1': partial(_squad_agg, 'best_f1'), # Best F1 (with varying threshold)
+        }
 
     def higher_is_better(self):
         """
@@ -79,5 +125,13 @@ class SQuAD(HFTask):
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            'exact': True, # Exact match (the normalized answer exactly match the gold answer)
+            'f1': True, # The F-score of predicted tokens versus the gold answer
+            'HasAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
+            'HasAns_f1': True, # The F-score of predicted tokens versus the gold answer
+            'NoAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
+            'NoAns_f1': True, # The F-score of predicted tokens versus the gold answer
+            'best_exact': True, # Best exact match (with varying threshold)
+            'best_f1': True, # Best F1 (with varying threshold)
+        }
@@ -29,4 +29,4 @@ def test_evaluator(taskname, Task):
     lm.loglikelihood = ll_fn
 
     evaluator.evaluate(lm, task_dict, False, 0, 10)
\ No newline at end of file
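For context on the test hunk above: the test replaces the model's request handlers with stub functions so `evaluator.evaluate` can be exercised end to end without loading a real LM. A hedged sketch of the kind of stubs involved (these names and return values are illustrative, not the test's actual fixtures):

```python
import random

# Stubs mimic the shapes the evaluator expects from each request type.
def ll_fn(requests):
    # One (logprob, is_greedy) pair per loglikelihood request.
    return [(-random.random(), False) for _ in requests]

def greedy_fn(requests):
    # One generated string per greedy_until request.
    return ["stub continuation" for _ in requests]
```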