base.py 3.2 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
import abc
import random

Jason Phang's avatar
gpt3  
Jason Phang committed
4

Leo Gao's avatar
Leo Gao committed
5
6
class LM(abc.ABC):
    """Abstract interface for a language model that can score continuations.

    Subclasses must implement :meth:`loglikelihood`.
    """

    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Compute the log-likelihood of generating a continuation from a context.

        Assume that the final text will simply be
            context + continuation

        :param context: str
            Context string
        :param continuation: str
            The continuation over which log likelihood will be calculated. If
            there is a word boundary, the space should be in the continuation.
            For example, context="hello" continuation=" world" is correct.
        :return: float
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        This base implementation ignores ``arg_string`` and simply calls
        ``cls()``; model classes that need arguments should override it.

        :param arg_string: str
            Left up to individual model class to handle
        :return: a new instance of ``cls``
        """
        return cls()

Leo Gao's avatar
Leo Gao committed
34
35

class Dataset(abc.ABC):
Leo Gao's avatar
Leo Gao committed
36
37
38
    @abc.abstractmethod
    def __init__(self):
        self.download()
Leo Gao's avatar
Leo Gao committed
39
        self._traindocs = None
sdtblck's avatar
sdtblck committed
40
41
42
43
44

    def download(self):
        """Downloads the task dataset if necessary"""
        pass

45
46
    @abc.abstractmethod
    def has_training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
47
        """Whether the task has a training set"""
48
49
50
51
        pass
    
    @abc.abstractmethod
    def has_validation_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
52
53
54
55
56
57
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
58
59
        pass

Leo Gao's avatar
Leo Gao committed
60
61
    @abc.abstractmethod
    def training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
62
63
64
65
66
        """

        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
Leo Gao's avatar
Leo Gao committed
67
68
69
70
71
72
73
74
75
76
77
        pass
    
    @abc.abstractmethod
    def validation_docs(self):
        pass
    
    @abc.abstractmethod
    def test_docs(self):
        pass
    
    def fewshot_examples(self, k):
Leo Gao's avatar
Leo Gao committed
78
79
80
81
        if self._traindocs is None:
            self._traindocs = list(self.training_docs())

        return random.sample(self._traindocs, k)
Leo Gao's avatar
Leo Gao committed
82
83
84
85

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
        pass
Leo Gao's avatar
Leo Gao committed
86
87
    
    @abc.abstractmethod
Leo Gao's avatar
Leo Gao committed
88
    def evaluate(self, docs, lm, provide_description, num_fewshot):
Jason Phang's avatar
checkin  
Jason Phang committed
89
90
91
92
93
94
95
96
97
98
99
100
        """Take iterable of docs and evaluates, returning a dict with the following format:

        {
            "major": float,
            "minor": dict,
            "higher_is_better": bool,
        }

        * `major` should be a single, representative number, for programmatic comparison
        * `minor` should be a dictionary containing all relevant sub-metrics
        * `higher_is_better` determines whether a higher metric is better
        """
Leo Gao's avatar
Leo Gao committed
101
        pass
Jason Phang's avatar
gpt3  
Jason Phang committed
102

Jason Phang's avatar
Jason Phang committed
103
    def fewshot_description(self):
Jason Phang's avatar
checkin  
Jason Phang committed
104
105
        return ""

Jason Phang's avatar
Jason Phang committed
106
    def fewshot_context(self, doc, num_fewshot, provide_description):
Jason Phang's avatar
Jason Phang committed
107
        raw_description = self.fewshot_description()
Jason Phang's avatar
Jason Phang committed
108
        description = (raw_description + "\n===\n\n") if provide_description and raw_description else ""
Jason Phang's avatar
Jason Phang committed
109
110
111
112
        labeled_examples = "\n\n".join(
            map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
        ) + "\n\n"
        example = self.doc_to_text(doc, include_target=False).strip()
sdtblck's avatar
sdtblck committed
113
        return description + labeled_examples + example