Commit 5a6c172e authored by sdtblck's avatar sdtblck
Browse files

change HF_Dataset to HFTask

parent 08f2ed6e
......@@ -4,7 +4,7 @@ import random
from ..base import Dataset
class HF_Dataset(Dataset):
class HFTask(Dataset):
DATASET_PATH = None
DATASET_NAME = None
......
......@@ -2,7 +2,7 @@ import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import HF_Dataset, simple_accuracy_metric, yesno
from . common import HFTask, simple_accuracy_metric, yesno
def get_accuracy_and_f1(preds, golds):
......@@ -22,7 +22,7 @@ def get_accuracy_and_f1(preds, golds):
}
class CoLA(HF_Dataset):
class CoLA(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "cola"
......@@ -55,7 +55,7 @@ class CoLA(HF_Dataset):
}
class MNLI(HF_Dataset):
class MNLI(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "mnli"
......@@ -106,7 +106,7 @@ class MNLI(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds)
class MRPC(HF_Dataset):
class MRPC(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "mrpc"
......@@ -144,7 +144,7 @@ class MRPC(HF_Dataset):
return get_accuracy_and_f1(preds=preds, golds=golds)
class RTE(HF_Dataset):
class RTE(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "rte"
......@@ -181,7 +181,7 @@ class RTE(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds)
class QNLI(HF_Dataset):
class QNLI(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "qnli"
......@@ -218,7 +218,7 @@ class QNLI(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds)
class QQP(HF_Dataset):
class QQP(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "qqp"
......@@ -256,7 +256,7 @@ class QQP(HF_Dataset):
return get_accuracy_and_f1(preds=preds, golds=golds)
class STSB(HF_Dataset):
class STSB(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "stsb"
......@@ -313,7 +313,7 @@ class STSB(HF_Dataset):
}
class SST(HF_Dataset):
class SST(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "sst2"
......@@ -350,7 +350,7 @@ class SST(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds)
class WNLI(HF_Dataset):
class WNLI(HFTask):
DATASET_PATH = "glue"
DATASET_NAME = "wnli"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment