Unverified commit 0f30237a authored by Leo Gao, committed by GitHub

Merge pull request #104 from jon-tow/race-evaluation

Implement `RACE` evaluation
parents e48b6082 48358352
@@ -14,8 +14,10 @@ from . import naturalqs
 from . import sat
 from . import arithmetic
 from . import lambada
+from . import race
 from . import piqa
 
 TASK_REGISTRY = {
     # GLUE
     "cola": glue.CoLA,
@@ -49,7 +51,7 @@ TASK_REGISTRY = {
     # "openbookqa": openbookqa.OpenBookQA, # not implemented yet
     # "sat": sat.SATAnalogies, # not implemented yet
     # "squad": squad.SQuAD, # not implemented yet
-    # "race": race.RACE, # not implemented yet
+    "race": race.RACE,
     # "naturalqs": naturalqs.NaturalQs, # not implemented yet
     # "webqs": webqs.WebQs, # not implemented yet
     # "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
...
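A quick way to sanity-check the new registry entry from a Python shell is sketched below. Note that the module path `lm_eval.tasks` is an assumption based on the relative imports above (the diff does not show the file path), and the snippet is illustrative rather than part of the change.

```python
# Hypothetical sanity check; assumes TASK_REGISTRY lives in lm_eval.tasks,
# which is not named explicitly in this diff.
from lm_eval import tasks

task_cls = tasks.TASK_REGISTRY["race"]  # -> race.RACE
task = task_cls()                       # HFTask subclass; may trigger the dataset download
print(task.has_training_docs())         # True, per the class shown below
```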
-from . common import HFTask
-from ..utils_stream import X, each, apply, join, filt, one
 import collections
 import datasets
+import numpy as np
+from lm_eval.base import rf, mean
+from . common import HFTask
+from ..utils_stream import each
 
 
 class RACE(HFTask):
@@ -9,6 +11,7 @@ class RACE(HFTask):
     DATASET_NAME = "high"
 
     cache = {}
+    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
 
     def has_training_docs(self):
         return True
@@ -54,13 +57,26 @@ class RACE(HFTask):
         # TODO: figure out description
         return ""
 
+    @classmethod
+    def get_answer_option(cls, problem):
+        answer = cls.letter_to_num[problem['answer']]
+        return problem['options'][answer]
+
+    @classmethod
+    def last_problem(cls, doc):
+        return doc['problems'][-1]
+
     def doc_to_text(self, doc):
-        # TODO: implement
-        pass
+        text = 'Article: ' + doc['article'] + '\n\n'
+        for problem in doc['problems'][:-1]:
+            question = 'Q: ' + problem['question'] + '\n\n'
+            answer = 'A: ' + self.get_answer_option(problem) + '\n\n'
+            text += question + answer
+        text += 'Q: ' + self.last_problem(doc)['question'] + '\n\n' + 'A:'
+        return text
 
     def doc_to_target(self, doc):
-        # TODO: implement
-        pass
+        return " " + self.get_answer_option(self.last_problem(doc))
 
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -73,9 +89,13 @@ class RACE(HFTask):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        problem = self.last_problem(doc)
+        ll_choices = [
+            rf.loglikelihood(ctx, " " + problem['options'][i])[0]
+            for i in range(4)
+        ]
+        return ll_choices
 
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
@@ -86,8 +106,11 @@ class RACE(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = self.letter_to_num[self.last_problem(doc)['answer']]
+        pred = np.argmax(results)
+        return {
+            "acc": int(pred == gold)
+        }
 
     def aggregation(self):
         """
@@ -95,8 +118,9 @@ class RACE(HFTask):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }
 
     def higher_is_better(self):
         """
@@ -104,5 +128,6 @@ class RACE(HFTask):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
+        return {
+            "acc": True
+        }
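To make the new prompt format and scoring concrete for reviewers, here is a small illustrative sketch of what `doc_to_text` / `doc_to_target` produce and how `process_results` turns per-option loglikelihoods into accuracy. The toy `doc` and the loglikelihood values are hypothetical; they only mirror the fields the methods above rely on (`article`, plus `problems` entries carrying `question`, `options`, `answer`).

```python
import numpy as np

# Hypothetical document; not from the dataset, only shaped like what the task expects.
doc = {
    'article': 'The quick brown fox jumps over the lazy dog.',
    'problems': [
        {'question': 'What does the fox jump over?',
         'options': ['the fence', 'the lazy dog', 'the river', 'the moon'],
         'answer': 'B'},
        {'question': 'What color is the fox?',
         'options': ['brown', 'black', 'white', 'grey'],
         'answer': 'A'},
    ],
}

# doc_to_text renders earlier problems as in-context Q/A pairs and leaves the
# last question open after "A:":
#
#   Article: The quick brown fox jumps over the lazy dog.
#
#   Q: What does the fox jump over?
#
#   A: the lazy dog
#
#   Q: What color is the fox?
#
#   A:
#
# doc_to_target returns " brown" (a leading space, then the gold option text).

# Scoring mirrors construct_requests/process_results: one loglikelihood per
# option, argmax compared against the gold index from letter_to_num.
results = [-12.3, -15.1, -14.0, -13.7]  # fake per-option loglikelihoods
gold = {'A': 0, 'B': 1, 'C': 2, 'D': 3}[doc['problems'][-1]['answer']]
acc = int(np.argmax(results) == gold)   # -> 1 for this fake example
print({"acc": acc})
```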