Unverified Commit eb9a5224 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #45 from cfoster0/winogrande

Add Winogrande dataset
parents 80f5fc3b 0e0e37f4
...@@ -4,6 +4,7 @@ from . import arc ...@@ -4,6 +4,7 @@ from . import arc
from . import race from . import race
from . import webqs from . import webqs
from . import anli from . import anli
from . import winogrande
from . import quac from . import quac
from . import hellaswag from . import hellaswag
from . import openbookqa from . import openbookqa
...@@ -36,6 +37,7 @@ TASK_REGISTRY = { ...@@ -36,6 +37,7 @@ TASK_REGISTRY = {
"squad": squad.SQuAD, "squad": squad.SQuAD,
"race": race.RACE, "race": race.RACE,
"webqs": webqs.WebQs, "webqs": webqs.WebQs,
"winogrande": winogrande.Winogrande,
"anli_r1": anli.ANLIRound1, "anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2, "anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3, "anli_r3": anli.ANLIRound3,
......
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import HFTask, simple_accuracy_metric, yesno
class Winogrande(HFTask):
    """Winogrande task wrapper built on ``HFTask``.

    Documents are accessed through the ``sentence``, ``option1``,
    ``option2`` and ``answer`` keys. The ``_`` placeholder inside
    ``sentence`` is substituted with the option selected by ``answer``
    when the target is requested.
    """

    DATASET_PATH = "winogrande"
    DATASET_NAME = "winogrande_xl"

    def has_training_docs(self):
        """Report that a training split is available."""
        return True

    def has_validation_docs(self):
        """Report that a validation split is available."""
        return True

    def has_test_docs(self):
        """Report that a test split is available."""
        return True

    # NOTE(review): ``self.data`` is presumably populated by ``HFTask`` from
    # DATASET_PATH / DATASET_NAME -- confirm against the base class.
    def training_docs(self):
        """Return the ``"train"`` split, or ``None`` if the task has none."""
        if self.has_training_docs():
            return self.data["train"]
        return None

    def validation_docs(self):
        """Return the ``"validation"`` split, or ``None`` if the task has none."""
        if self.has_validation_docs():
            return self.data["validation"]
        return None

    def test_docs(self):
        """Return the ``"test"`` split, or ``None`` if the task has none."""
        if self.has_test_docs():
            return self.data["test"]
        return None

    def fewshot_description(self):
        """Return the natural-language task description used for few-shot prompts."""
        return (
            "Winograd schema sentence including either a ___ blank with a "
            "missing word, making the pronoun ambiguous, or the same with "
            "the word filled in."
        )

    def doc_to_text(self, doc, include_target=True):
        """Render a document as text.

        Args:
            doc: Mapping with a ``sentence`` key and, when the target is
                included, ``answer``, ``option1`` and ``option2`` keys.
            include_target: When true, fill the ``_`` placeholder in the
                sentence with the option named by ``doc['answer']``;
                otherwise return the sentence unchanged.

        Returns:
            The sentence, with the blank filled in if requested.

        Raises:
            ValueError: If ``include_target`` is true and ``doc['answer']``
                is neither ``'1'`` nor ``'2'``.
        """
        text = doc['sentence']
        if include_target:
            answer_n = doc['answer']
            # The answer key is compared as a string, not an integer.
            if answer_n == '1':
                answer = doc['option1']
            elif answer_n == '2':
                answer = doc['option2']
            else:
                raise ValueError("Winogrande from HF datasets contained an invalid answer key")
            text = text.replace("_", answer)
        return text

    def evaluate(self, docs, lm, provide_description, num_fewshot):
        """Evaluate ``lm`` on ``docs``; not implemented yet."""
        # TODO: Write evaluation function
        raise NotImplementedError()
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment