Commit b0e715ee authored by sdtblck's avatar sdtblck
Browse files

name changes

parent ff2d2ada
...@@ -4,14 +4,14 @@ import random ...@@ -4,14 +4,14 @@ import random
from ..base import Dataset from ..base import Dataset
class NLP_TASK(Dataset): class HF_Dataset(Dataset):
NLP_PATH = None DATASET_PATH = None
NLP_NAME = None DATASET_NAME = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._training_docs = None self._training_docs = None
self.data = datasets.load_dataset(path=self.NLP_PATH, name=self.NLP_NAME) self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
def has_training_docs(self): def has_training_docs(self):
"""Whether the task has a training set""" """Whether the task has a training set"""
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
from scipy.stats import pearsonr, spearmanr from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib from tqdm import auto as tqdm_lib
from . common import NLP_TASK, simple_accuracy_metric, yesno from . common import HF_Dataset, simple_accuracy_metric, yesno
def get_accuracy_and_f1(preds, golds): def get_accuracy_and_f1(preds, golds):
...@@ -22,18 +22,9 @@ def get_accuracy_and_f1(preds, golds): ...@@ -22,18 +22,9 @@ def get_accuracy_and_f1(preds, golds):
} }
class CoLA(NLP_TASK): class CoLA(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "cola" DATASET_NAME = "cola"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return True
def fewshot_description(self): def fewshot_description(self):
return "Does this sentence make sense?:\tTrue or False?" return "Does this sentence make sense?:\tTrue or False?"
...@@ -64,9 +55,9 @@ class CoLA(NLP_TASK): ...@@ -64,9 +55,9 @@ class CoLA(NLP_TASK):
} }
class MNLI(NLP_TASK): class MNLI(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "mnli" DATASET_NAME = "mnli"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -79,11 +70,11 @@ class MNLI(NLP_TASK): ...@@ -79,11 +70,11 @@ class MNLI(NLP_TASK):
def validation_docs(self): def validation_docs(self):
if self.has_validation_docs(): if self.has_validation_docs():
return self._load_nlp_dataset()["validation_matched"] return self.data["validation_matched"]
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self._load_nlp_dataset()["test_matched"] return self.data["test_matched"]
def doc_to_text(self, doc, include_target=True): def doc_to_text(self, doc, include_target=True):
text = "{}\nquestion:\t{}\tTrue, False or Neither?\nanswer:".format( text = "{}\nquestion:\t{}\tTrue, False or Neither?\nanswer:".format(
...@@ -115,9 +106,9 @@ class MNLI(NLP_TASK): ...@@ -115,9 +106,9 @@ class MNLI(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class MRPC(NLP_TASK): class MRPC(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "mrpc" DATASET_NAME = "mrpc"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -153,9 +144,9 @@ class MRPC(NLP_TASK): ...@@ -153,9 +144,9 @@ class MRPC(NLP_TASK):
return get_accuracy_and_f1(preds=preds, golds=golds) return get_accuracy_and_f1(preds=preds, golds=golds)
class RTE(NLP_TASK): class RTE(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "rte" DATASET_NAME = "rte"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -190,9 +181,9 @@ class RTE(NLP_TASK): ...@@ -190,9 +181,9 @@ class RTE(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class QNLI(NLP_TASK): class QNLI(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "qnli" DATASET_NAME = "qnli"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -227,9 +218,9 @@ class QNLI(NLP_TASK): ...@@ -227,9 +218,9 @@ class QNLI(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class QQP(NLP_TASK): class QQP(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "qqp" DATASET_NAME = "qqp"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -265,9 +256,9 @@ class QQP(NLP_TASK): ...@@ -265,9 +256,9 @@ class QQP(NLP_TASK):
return get_accuracy_and_f1(preds=preds, golds=golds) return get_accuracy_and_f1(preds=preds, golds=golds)
class STSB(NLP_TASK): class STSB(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "stsb" DATASET_NAME = "stsb"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -322,9 +313,9 @@ class STSB(NLP_TASK): ...@@ -322,9 +313,9 @@ class STSB(NLP_TASK):
} }
class SST(NLP_TASK): class SST(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "sst2" DATASET_NAME = "sst2"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -359,9 +350,9 @@ class SST(NLP_TASK): ...@@ -359,9 +350,9 @@ class SST(NLP_TASK):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class WNLI(NLP_TASK): class WNLI(HF_Dataset):
NLP_PATH = "glue" DATASET_PATH = "glue"
NLP_NAME = "wnli" DATASET_NAME = "wnli"
def has_training_docs(self): def has_training_docs(self):
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment