Commit 5a6c172e authored by sdtblck's avatar sdtblck
Browse files

change HF_Dataset to HFTask

parent 08f2ed6e
...@@ -4,7 +4,7 @@ import random ...@@ -4,7 +4,7 @@ import random
from ..base import Dataset from ..base import Dataset
class HF_Dataset(Dataset): class HFTask(Dataset):
DATASET_PATH = None DATASET_PATH = None
DATASET_NAME = None DATASET_NAME = None
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
from scipy.stats import pearsonr, spearmanr from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib from tqdm import auto as tqdm_lib
from . common import HF_Dataset, simple_accuracy_metric, yesno from . common import HFTask, simple_accuracy_metric, yesno
def get_accuracy_and_f1(preds, golds): def get_accuracy_and_f1(preds, golds):
...@@ -22,7 +22,7 @@ def get_accuracy_and_f1(preds, golds): ...@@ -22,7 +22,7 @@ def get_accuracy_and_f1(preds, golds):
} }
class CoLA(HF_Dataset): class CoLA(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "cola" DATASET_NAME = "cola"
...@@ -55,7 +55,7 @@ class CoLA(HF_Dataset): ...@@ -55,7 +55,7 @@ class CoLA(HF_Dataset):
} }
class MNLI(HF_Dataset): class MNLI(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "mnli" DATASET_NAME = "mnli"
...@@ -106,7 +106,7 @@ class MNLI(HF_Dataset): ...@@ -106,7 +106,7 @@ class MNLI(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class MRPC(HF_Dataset): class MRPC(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "mrpc" DATASET_NAME = "mrpc"
...@@ -144,7 +144,7 @@ class MRPC(HF_Dataset): ...@@ -144,7 +144,7 @@ class MRPC(HF_Dataset):
return get_accuracy_and_f1(preds=preds, golds=golds) return get_accuracy_and_f1(preds=preds, golds=golds)
class RTE(HF_Dataset): class RTE(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "rte" DATASET_NAME = "rte"
...@@ -181,7 +181,7 @@ class RTE(HF_Dataset): ...@@ -181,7 +181,7 @@ class RTE(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class QNLI(HF_Dataset): class QNLI(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "qnli" DATASET_NAME = "qnli"
...@@ -218,7 +218,7 @@ class QNLI(HF_Dataset): ...@@ -218,7 +218,7 @@ class QNLI(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class QQP(HF_Dataset): class QQP(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "qqp" DATASET_NAME = "qqp"
...@@ -256,7 +256,7 @@ class QQP(HF_Dataset): ...@@ -256,7 +256,7 @@ class QQP(HF_Dataset):
return get_accuracy_and_f1(preds=preds, golds=golds) return get_accuracy_and_f1(preds=preds, golds=golds)
class STSB(HF_Dataset): class STSB(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "stsb" DATASET_NAME = "stsb"
...@@ -313,7 +313,7 @@ class STSB(HF_Dataset): ...@@ -313,7 +313,7 @@ class STSB(HF_Dataset):
} }
class SST(HF_Dataset): class SST(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "sst2" DATASET_NAME = "sst2"
...@@ -350,7 +350,7 @@ class SST(HF_Dataset): ...@@ -350,7 +350,7 @@ class SST(HF_Dataset):
return simple_accuracy_metric(preds=preds, golds=golds) return simple_accuracy_metric(preds=preds, golds=golds)
class WNLI(HF_Dataset): class WNLI(HFTask):
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "wnli" DATASET_NAME = "wnli"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment