base.py 3.53 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
import abc
import random

Jason Phang's avatar
gpt3  
Jason Phang committed
4

Leo Gao's avatar
Leo Gao committed
5
6
class LM(abc.ABC):
    """Abstract interface for a language-model backend.

    Concrete subclasses must implement :meth:`generate` and
    :meth:`loglikelihood`; the class methods below have default
    implementations that subclasses may override.
    """

    @abc.abstractmethod
    def generate(self, context, max_gen_length):
        """Conditional text generation with an LM

        :param context: str
            Context string for conditional generation
        :param max_gen_length: int
            Maximum number of tokens to generate
        :return: str
        """
        pass

    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Compute the log-likelihood of a continuation given a context

        Assume that the final text will simply be
            context + continuation

        :param context: str
            Context string the continuation is conditioned on
        :param continuation: str
            Text whose log-likelihood is computed, given the context
        :return: float
        """
        pass

    @classmethod
    def num_tokens(cls, string):
        """Return the number of tokens in a string, based on tokenization

        The base class has no tokenizer, so this default is a placeholder
        that returns None; override it in model classes that can tokenize.

        :param string: str
            Input string
        :return: int
            Token count (None from this base placeholder)
        """
        return None

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        The default implementation ignores ``arg_string`` and calls the
        class with no arguments.

        :param arg_string: str
            Left up to individual model class to handle
        :return: LM
            A new instance of ``cls``
        """
        return cls()

Leo Gao's avatar
Leo Gao committed
55
56

class Dataset(abc.ABC):
Leo Gao's avatar
Leo Gao committed
57
58
59
    @abc.abstractmethod
    def __init__(self):
        self.download()
sdtblck's avatar
sdtblck committed
60
61
62
63
64

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

65
66
    @abc.abstractmethod
    def has_training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
67
        """Whether the task has a training set"""
68
69
70
71
        pass
    
    @abc.abstractmethod
    def has_validation_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
72
73
74
75
76
77
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
78
79
        pass

Leo Gao's avatar
Leo Gao committed
80
81
    @abc.abstractmethod
    def training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
82
83
84
85
86
        """

        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
Leo Gao's avatar
Leo Gao committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
        pass
    
    @abc.abstractmethod
    def validation_docs(self):
        pass
    
    @abc.abstractmethod
    def test_docs(self):
        pass
    
    def fewshot_examples(self, k):
        traindocs = list(self.training_docs())
        random.shuffle(traindocs)
        return traindocs[:k]

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
        pass
Leo Gao's avatar
Leo Gao committed
105
106
    
    @abc.abstractmethod
Leo Gao's avatar
Leo Gao committed
107
    def evaluate(self, docs, lm, provide_description, num_fewshot):
Jason Phang's avatar
checkin  
Jason Phang committed
108
109
110
111
112
113
114
115
116
117
118
119
        """Take iterable of docs and evaluates, returning a dict with the following format:

        {
            "major": float,
            "minor": dict,
            "higher_is_better": bool,
        }

        * `major` should be a single, representative number, for programmatic comparison
        * `minor` should be a dictionary containing all relevant sub-metrics
        * `higher_is_better` determines whether a higher metric is better
        """
Leo Gao's avatar
Leo Gao committed
120
        pass
Jason Phang's avatar
gpt3  
Jason Phang committed
121

Jason Phang's avatar
Jason Phang committed
122
    def fewshot_description(self):
Jason Phang's avatar
checkin  
Jason Phang committed
123
124
        return ""

Jason Phang's avatar
Jason Phang committed
125
    def fewshot_context(self, doc, num_fewshot, provide_description):
Jason Phang's avatar
Jason Phang committed
126
        raw_description = self.fewshot_description()
Jason Phang's avatar
Jason Phang committed
127
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
Jason Phang's avatar
Jason Phang committed
128
129
130
131
        labeled_examples = "\n\n".join(
            map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
        ) + "\n\n"
        example = self.doc_to_text(doc, include_target=False).strip()
sdtblck's avatar
sdtblck committed
132
        return description + labeled_examples + example