common.py 1.47 KB
Newer Older
1
import datasets
2
from ..base import Task
Jason Phang's avatar
checkin  
Jason Phang committed
3
4


5
class HFTask(Task):
    """Task whose documents are loaded through ``datasets.load_dataset``.

    Subclasses set ``DATASET_PATH`` (and, where needed, ``DATASET_NAME``);
    both are forwarded unchanged to ``datasets.load_dataset`` by
    :meth:`download`. The loaded object is kept on ``self.data`` and is
    looked up by the split names ``"train"``, ``"validation"`` and ``"test"``.
    """

    # Forwarded as ``path=`` to ``datasets.load_dataset``.
    DATASET_PATH = None
    # Forwarded as ``name=`` to ``datasets.load_dataset``.
    DATASET_NAME = None

    def __init__(self):
        # Define the attribute before delegating to the base initialiser.
        # NOTE(review): presumably ``Task.__init__`` triggers ``download()``,
        # which is why this assignment must come first -- confirm in ..base.
        self.data = None
        super().__init__()

    def download(self):
        """Load the dataset and store the result on ``self.data``."""
        self.data = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)

    def has_training_docs(self):
        """Whether the task has a training set"""
        return "train" in self.data

    def has_validation_docs(self):
        """Whether the task has a validation set"""
        return "validation" in self.data

    def has_test_docs(self):
        """Whether the task has a test set"""
        return "test" in self.data

    def _convert_standard(self, doc):
        """Per-document conversion hook applied to every yielded document.

        The default implementation returns ``doc`` unchanged.
        """
        return doc

    def training_docs(self):
        """Return the converted training documents as a cached list.

        Returns ``None`` when the task has no ``"train"`` split.
        """
        # Cache training for faster few-shot.
        # If data is too large to fit in memory, override this method.
        if not self.has_training_docs():
            return None
        # NOTE(review): ``_training_docs`` is not initialised in this class;
        # presumably the base ``Task`` sets it to None -- confirm in ..base.
        if self._training_docs is None:
            self._training_docs = list(map(self._convert_standard, self.data["train"]))
        return self._training_docs

    def validation_docs(self):
        """Return a lazy iterator over the converted validation documents.

        The result is a ``map`` object, so it can be consumed only once.
        Returns ``None`` when the task has no ``"validation"`` split.
        """
        if not self.has_validation_docs():
            return None
        return map(self._convert_standard, self.data["validation"])

    def test_docs(self):
        """Return a lazy iterator over the converted test documents.

        The result is a ``map`` object, so it can be consumed only once.
        Returns ``None`` when the task has no ``"test"`` split.
        """
        if not self.has_test_docs():
            return None
        return map(self._convert_standard, self.data["test"])
Jason Phang's avatar
checkin  
Jason Phang committed
46
47


Jason Phang's avatar
Jason Phang committed
48
49
50
51
52
def yesno(x):
    """Return ``'yes'`` when ``x`` is truthy, otherwise ``'no'``."""
    return 'yes' if x else 'no'