Commit b0e715ee authored by sdtblck's avatar sdtblck
Browse files

name changes

parent ff2d2ada
......@@ -4,14 +4,14 @@ import random
from ..base import Dataset
class NLP_TASK(Dataset):
NLP_PATH = None
NLP_NAME = None
class HF_Dataset(Dataset):
DATASET_PATH = None
DATASET_NAME = None
def __init__(self):
super().__init__()
self._training_docs = None
self.data = datasets.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME)
self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
def has_training_docs(self):
"""Whether the task has a training set"""
......
......@@ -2,7 +2,7 @@ import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import NLP_TASK, simple_accuracy_metric, yesno
from . common import HF_Dataset, simple_accuracy_metric, yesno
def get_accuracy_and_f1(preds, golds):
......@@ -22,18 +22,9 @@ def get_accuracy_and_f1(preds, golds):
}
class CoLA(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "cola"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
class CoLA(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "cola"
def fewshot_description(self):
return "Does this sentence make sense?:\tTrue or False?"
......@@ -64,9 +55,9 @@ class CoLA(NLP_TASK):
}
class MNLI(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "mnli"
class MNLI(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "mnli"
def has_training_docs(self):
return True
......@@ -79,11 +70,11 @@ class MNLI(NLP_TASK):
def validation_docs(self):
if self.has_validation_docs():
return self._load_nlp_dataset()["validation_matched"]
return self.data["validation_matched"]
def test_docs(self):
if self.has_test_docs():
return self._load_nlp_dataset()["test_matched"]
return self.data["test_matched"]
def doc_to_text(self, doc, include_target=True):
text = "{}\nquestion:\t{}\tTrue, False or Neither?\nanswer:".format(
......@@ -115,9 +106,9 @@ class MNLI(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
class MRPC(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "mrpc"
class MRPC(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "mrpc"
def has_training_docs(self):
return True
......@@ -153,9 +144,9 @@ class MRPC(NLP_TASK):
return get_accuracy_and_f1(preds=preds, golds=golds)
class RTE(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "rte"
class RTE(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "rte"
def has_training_docs(self):
return True
......@@ -190,9 +181,9 @@ class RTE(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
class QNLI(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "qnli"
class QNLI(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "qnli"
def has_training_docs(self):
return True
......@@ -227,9 +218,9 @@ class QNLI(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
class QQP(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "qqp"
class QQP(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "qqp"
def has_training_docs(self):
return True
......@@ -265,9 +256,9 @@ class QQP(NLP_TASK):
return get_accuracy_and_f1(preds=preds, golds=golds)
class STSB(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "stsb"
class STSB(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "stsb"
def has_training_docs(self):
return True
......@@ -322,9 +313,9 @@ class STSB(NLP_TASK):
}
class SST(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "sst2"
class SST(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "sst2"
def has_training_docs(self):
return True
......@@ -359,9 +350,9 @@ class SST(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds)
class WNLI(NLP_TASK):
NLP_PATH = "glue"
NLP_NAME = "wnli"
class WNLI(HF_Dataset):
DATASET_PATH = "glue"
DATASET_NAME = "wnli"
def has_training_docs(self):
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment