base.py 3.12 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
import abc
import random

Jason Phang's avatar
gpt3  
Jason Phang committed
4

Leo Gao's avatar
Leo Gao committed
5
6
class LM(abc.ABC):
    """Abstract interface for a language model that tasks can score text with."""

    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Return the log-likelihood of ``continuation`` following ``context``.

        :param context: str
            Text the model is conditioned on.
        :param continuation: str
            Text whose log-likelihood is scored. Any word-boundary whitespace
            belongs at the start of the continuation, e.g.
            context="hello" with continuation=" world" is correct.
        :return: float
        """

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Alternate constructor that accepts a model-specific argument string.

        Models needing extra configuration (e.g. an OpenAI API engine or a
        path to load from) can override this and parse ``arg_string``
        themselves. The base implementation ignores ``arg_string`` and
        simply calls ``cls()`` with no arguments.

        :param arg_string: str
            Format is left up to each model class.
        :return: a new instance of ``cls``
        """
        model = cls()
        return model

Leo Gao's avatar
Leo Gao committed
31
32

class Dataset(abc.ABC):
Leo Gao's avatar
Leo Gao committed
33
34
35
    @abc.abstractmethod
    def __init__(self):
        self.download()
Leo Gao's avatar
Leo Gao committed
36
        self._traindocs = None
sdtblck's avatar
sdtblck committed
37
38
39
40
41

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

42
43
    @abc.abstractmethod
    def has_training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
44
        """Whether the task has a training set"""
45
46
47
48
        pass
    
    @abc.abstractmethod
    def has_validation_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
49
50
51
52
53
54
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
55
56
        pass

Leo Gao's avatar
Leo Gao committed
57
58
    @abc.abstractmethod
    def training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
59
60
61
62
63
        """

        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
Leo Gao's avatar
Leo Gao committed
64
65
66
67
68
69
70
71
72
73
74
        pass
    
    @abc.abstractmethod
    def validation_docs(self):
        pass
    
    @abc.abstractmethod
    def test_docs(self):
        pass
    
    def fewshot_examples(self, k):
Leo Gao's avatar
Leo Gao committed
75
76
77
78
        if self._traindocs is None:
            self._traindocs = list(self.training_docs())

        return random.sample(self._traindocs, k)
Leo Gao's avatar
Leo Gao committed
79
80
81
82

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
        pass
Leo Gao's avatar
Leo Gao committed
83
84
    
    @abc.abstractmethod
Leo Gao's avatar
Leo Gao committed
85
    def evaluate(self, docs, lm, provide_description, num_fewshot):
Jason Phang's avatar
checkin  
Jason Phang committed
86
87
88
89
90
91
92
93
94
95
96
97
        """Take iterable of docs and evaluates, returning a dict with the following format:

        {
            "major": float,
            "minor": dict,
            "higher_is_better": bool,
        }

        * `major` should be a single, representative number, for programmatic comparison
        * `minor` should be a dictionary containing all relevant sub-metrics
        * `higher_is_better` determines whether a higher metric is better
        """
Leo Gao's avatar
Leo Gao committed
98
        pass
Jason Phang's avatar
gpt3  
Jason Phang committed
99

Jason Phang's avatar
Jason Phang committed
100
    def fewshot_description(self):
Jason Phang's avatar
checkin  
Jason Phang committed
101
102
        return ""

Jason Phang's avatar
Jason Phang committed
103
    def fewshot_context(self, doc, num_fewshot, provide_description):
Jason Phang's avatar
Jason Phang committed
104
        raw_description = self.fewshot_description()
Jason Phang's avatar
Jason Phang committed
105
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
Jason Phang's avatar
Jason Phang committed
106
107
108
109
        labeled_examples = "\n\n".join(
            map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
        ) + "\n\n"
        example = self.doc_to_text(doc, include_target=False).strip()
sdtblck's avatar
sdtblck committed
110
        return description + labeled_examples + example