base.py 3.79 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
import abc
import random

Jason Phang's avatar
gpt3  
Jason Phang committed
4

Leo Gao's avatar
Leo Gao committed
5
6
class LM(abc.ABC):
    """Abstract base class describing the interface of a language model.

    Concrete subclasses must implement :meth:`generate` and
    :meth:`loglikelihood`; the class methods below have placeholder
    implementations that subclasses may override.
    """

    @abc.abstractmethod
    def generate(self, context, max_gen_length):
        """Conditional text generation with an LM

        :param context: str
            Context string for conditional generation
        :param max_gen_length: int
            Maximum number of tokens to generate
        :return: str
        """
        return None

    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Compute the log-likelihood of generating a continuation from a context

        Assume that the final text will simply be
            context + continuation

        :param context: str
            Context string the continuation is conditioned on
        :param continuation: str
            Continuation string whose log-likelihood is computed
        :return: float
        """
        return None

    @classmethod
    def num_tokens(cls, string):
        """Return the number of tokens in a string, based on tokenization

        The base implementation is only a placeholder and returns None;
        subclasses are expected to override it with their own tokenization.

        :param string: str
            Input string
        :return: int
        """
        return None

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        The base implementation ignores ``arg_string`` and simply calls
        ``cls()`` with no arguments.

        :param arg_string: str
            Left up to individual model class to handle
        :return: LM
            A new instance of ``cls``
        """
        return cls()

Leo Gao's avatar
Leo Gao committed
55
56

class Dataset(abc.ABC):
    """Abstract base class for an evaluation task and its documents.

    Subclasses provide the document splits, a way to render a document as
    text, and an ``evaluate`` implementation. The few-shot helpers defined
    here build prompts from those pieces.
    """

    @abc.abstractmethod
    def has_training_docs(self):
        """Whether the task has a training set"""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
        pass

    @abc.abstractmethod
    def training_docs(self):
        """Return the training documents

        :return: Iterable[obj]
            An iterable of any object that doc_to_text can handle
        """
        pass

    @abc.abstractmethod
    def validation_docs(self):
        """Return the validation documents

        :return: Iterable[obj]
            Presumably the same kind of objects as training_docs yields
        """
        pass

    @abc.abstractmethod
    def test_docs(self):
        """Return the test documents

        :return: Iterable[obj]
            Presumably the same kind of objects as training_docs yields
        """
        pass

    def fewshot_examples(self, k):
        """Pick training documents at random to use as few-shot examples

        The training documents are copied into a list, shuffled with the
        module-level ``random`` generator, and the first ``k`` are returned,
        so fewer than ``k`` documents come back when the training set is small.

        :param k: int
            Number of examples requested
        :return: list[obj]
        """
        traindocs = list(self.training_docs())
        random.shuffle(traindocs)
        return traindocs[:k]

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
        """Render a single document as text

        :param doc: obj
            A document as yielded by one of the ``*_docs`` methods
        :param include_target: bool
            fewshot_context passes False for the document being evaluated
            and leaves the default (True) for the labeled examples
        :return: str
        """
        pass

    @abc.abstractmethod
    def evaluate(self, docs, lm, provide_description, num_fewshot):
        """Take iterable of docs and evaluates, returning a dict with the following format:

        {
            "major": float,
            "minor": dict,
            "higher_is_better": bool,
        }

        * `major` should be a single, representative number, for programmatic comparison
        * `minor` should be a dictionary containing all relevant sub-metrics
        * `higher_is_better` determines whether a higher metric is better
        """
        pass

    def fewshot_description(self):
        """Return the task description placed before the examples

        The base implementation returns an empty string, which makes
        fewshot_context omit the description section entirely.

        :return: str
        """
        return ""

    def fewshot_context(self, doc, num_fewshot, provide_description):
        """Build the prompt for ``doc``: description, labeled examples, query

        :param doc: obj
            The document to build a prompt for
        :param num_fewshot: int
            Number of labeled training examples to place before ``doc``
        :param provide_description: bool
            Whether to prepend fewshot_description() (skipped when empty)
        :return: str
        """
        raw_description = self.fewshot_description()
        if provide_description and raw_description:
            description = raw_description + "\n===\n\n"
        else:
            description = ""

        # Zero-shot prompts must not touch the training set (the task may not
        # have one) and must not start with a dangling "\n\n" separator.
        if num_fewshot == 0:
            shots = []
        else:
            shots = self.fewshot_examples(k=num_fewshot)

        if shots:
            labeled_examples = "\n\n".join(map(self.doc_to_text, shots)) + "\n\n"
        else:
            labeled_examples = ""

        example = self.doc_to_text(doc, include_target=False).strip()
        return description + labeled_examples + example
Jason Phang's avatar
checkin  
Jason Phang committed
125

Jason Phang's avatar
gpt3  
Jason Phang committed
126
127
128
129
130
131
132
133
134
135
136
137
138

class Registry:
    """A named mapping from string keys to registered classes."""

    def __init__(self, registry_name):
        """Create an empty registry

        :param registry_name: str
            Label for this registry, used in error messages
        """
        self.registry_name = registry_name
        self.registry = {}

    def register(self, name):
        """Return a class decorator that stores the class under ``name``

        :param name: str
            Key the decorated class is registered under
        :return: callable
            Decorator that records the class and returns it unchanged
        :raises ValueError:
            Raised by the decorator when ``name`` is already registered
        """
        def register_cls(new_cls):
            if name in self.registry:
                # Report both the registry and the clashing key; the format
                # string previously had a single placeholder and dropped `name`.
                raise ValueError(
                    'Cannot register duplicate {} ({})'.format(self.registry_name, name)
                )
            self.registry[name] = new_cls
            return new_cls
        return register_cls