base.py 2.36 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
import abc
import random

Jason Phang's avatar
gpt3  
Jason Phang committed
4

Leo Gao's avatar
Leo Gao committed
5
6
class LM(abc.ABC):
    """Abstract base class for language-model backends.

    Subclasses must implement ``generate`` and ``loglikelihood``; the class
    cannot be instantiated until both are overridden.
    """

    @abc.abstractmethod
    def generate(self, context, max_gen_length):
        """Conditional text generation with an LM

        :param context: str
            Context string for conditional generation
        :param max_gen_length: int
            Maximum number of tokens to generate
        :return: str
            The generated text
        """
        pass

    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Compute the log-probability of a continuation given a context

        Assume that the final text will simply be
            context + continuation

        :param context: str
            Context string the continuation is conditioned on
        :param continuation: str
            Continuation string whose log-probability is computed
        :return: float
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        The default implementation ignores ``arg_string`` and calls the
        class with no arguments; subclasses that need configuration are
        expected to override it.

        :param arg_string: str
            Left up to individual model class to handle
        :return: an instance of ``cls``
        """
        return cls()

Leo Gao's avatar
Leo Gao committed
45
46

class Dataset(abc.ABC):
    """Abstract base class for an evaluation dataset / task.

    Concrete subclasses must override every abstract method below before
    they can be instantiated. ``fewshot_examples`` is the only method with a
    default implementation.
    """

    @abc.abstractmethod
    def has_training_docs(self):
        """Report whether training documents are available (subclass-defined)."""
        pass

    @abc.abstractmethod
    def has_validation_docs(self):
        """Report whether validation documents are available (subclass-defined)."""
        pass

    @abc.abstractmethod
    def training_docs(self):
        """Return an iterable of training documents.

        ``fewshot_examples`` passes the result to ``list()``, so any
        iterable (list, generator, ...) works.
        """
        pass

    @abc.abstractmethod
    def validation_docs(self):
        """Return the validation documents (subclass-defined)."""
        pass

    @abc.abstractmethod
    def test_docs(self):
        """Return the test documents (subclass-defined)."""
        pass

    def fewshot_examples(self, k):
        """Return up to ``k`` training documents in a fixed shuffled order.

        The shuffle always uses seed 123, so repeated calls return the same
        examples. A private ``random.Random`` instance is used instead of
        reseeding the module-level generator, so calling this method does
        not disturb the global random state; the resulting order is the same
        as ``random.seed(123)`` followed by ``random.shuffle``.

        :param k: int
            Number of examples to return; fewer are returned when there are
            not enough training documents
        :return: list
        """
        traindocs = list(self.training_docs())
        random.Random(123).shuffle(traindocs)

        return traindocs[:k]

    @abc.abstractmethod
    def fewshot_description(self):
        """Return the description used for few-shot prompts (subclass-defined)."""
        pass

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
        """Convert a document to text (subclass-defined).

        :param doc: a single document, in whatever form the subclass uses
        :param include_target: bool
            Presumably controls whether the target/answer is part of the
            returned text -- confirm against concrete subclasses
        """
        pass

    @abc.abstractmethod
    def evaluate(self, docs, lm, provide_description, num_fewshot):
        """Evaluate ``lm`` on ``docs`` (subclass-defined).

        :param docs: documents to evaluate on
        :param lm: the language model under evaluation
        :param provide_description: presumably whether to include the
            few-shot description in the prompt -- confirm in subclasses
        :param num_fewshot: presumably the number of few-shot examples to
            include -- confirm in subclasses
        """
        pass
Jason Phang's avatar
gpt3  
Jason Phang committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98


class Registry:
    """Name-to-class lookup table populated through a decorator.

    Usage::

        MODELS = Registry('models')

        @MODELS.register('dummy')
        class DummyLM: ...
    """

    def __init__(self, registry_name):
        """
        :param registry_name: str
            Label for this registry, used in error messages
        """
        self.registry_name = registry_name
        # Maps registered name -> registered class.
        self.registry = {}

    def register(self, name):
        """Return a class decorator that stores the class under ``name``.

        :param name: str
            Key under which the decorated class is stored
        :return: decorator that registers the class and returns it unchanged
        :raises ValueError: when the decorator is applied and ``name`` is
            already registered
        """
        def register_cls(new_cls):
            if name in self.registry:
                # Report both the registry and the offending name; the
                # previous message had a single placeholder and dropped the
                # duplicated name.
                raise ValueError(
                    'Cannot register duplicate {} ({})'.format(self.registry_name, name)
                )
            self.registry[name] = new_cls
            return new_cls
        return register_cls