"ppstructure/table/table_metric/parallel.py" did not exist on "6127aad993bab109195292296b292781e03e74a2"
base.py 3.57 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
import abc
import random

Jason Phang's avatar
gpt3  
Jason Phang committed
4

Leo Gao's avatar
Leo Gao committed
5
6
class LM(abc.ABC):
    """Abstract interface for a language model that tasks can query.

    Subclasses must implement ``generate`` and ``loglikelihood``; models that
    need configuration should override ``create_from_arg_string``.
    """

    @abc.abstractmethod
    def generate(self, context, max_gen_length):
        """Conditional text generation with an LM

        :param context: str
            Context string for conditional generation
        :param max_gen_length: int
            Maximum number of tokens to generate
        :return: str
            The generated text
        """
        pass

    @abc.abstractmethod
    def loglikelihood(self, context, continuation):
        """Compute the log-likelihood of generating a continuation from a context

        Assume that the final text will simply be
            context + continuation

        :param context: str
            Context string for conditional generation
        :param continuation: str
            Text whose log-likelihood is scored, conditioned on `context`
        :return: float
        """
        pass

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Constructor method, in case models need additional arguments
        e.g. OpenAI API engine, paths for loading, other params

        The default implementation ignores `arg_string` and calls ``cls()``
        with no arguments; subclasses needing configuration must override it.

        :param arg_string: str
            Left up to individual model class to handle
        :return: an instance of `cls`
        """
        return cls()

Leo Gao's avatar
Leo Gao committed
45
46

class Dataset(abc.ABC):
47
48
    @abc.abstractmethod
    def has_training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
49
        """Whether the task has a training set"""
50
51
52
53
        pass
    
    @abc.abstractmethod
    def has_validation_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
54
55
56
57
58
59
        """Whether the task has a validation set"""
        pass

    @abc.abstractmethod
    def has_test_docs(self):
        """Whether the task has a test set"""
60
61
        pass

Leo Gao's avatar
Leo Gao committed
62
63
    @abc.abstractmethod
    def training_docs(self):
Jason Phang's avatar
checkin  
Jason Phang committed
64
65
66
67
68
        """

        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
Leo Gao's avatar
Leo Gao committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        pass
    
    @abc.abstractmethod
    def validation_docs(self):
        pass
    
    @abc.abstractmethod
    def test_docs(self):
        pass
    
    def fewshot_examples(self, k):
        traindocs = list(self.training_docs())
        random.shuffle(traindocs)
        return traindocs[:k]

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
        pass
Leo Gao's avatar
Leo Gao committed
87
88
    
    @abc.abstractmethod
Leo Gao's avatar
Leo Gao committed
89
    def evaluate(self, docs, lm, provide_description, num_fewshot):
Jason Phang's avatar
checkin  
Jason Phang committed
90
91
92
93
94
95
96
97
98
99
100
101
        """Take iterable of docs and evaluates, returning a dict with the following format:

        {
            "major": float,
            "minor": dict,
            "higher_is_better": bool,
        }

        * `major` should be a single, representative number, for programmatic comparison
        * `minor` should be a dictionary containing all relevant sub-metrics
        * `higher_is_better` determines whether a higher metric is better
        """
Leo Gao's avatar
Leo Gao committed
102
        pass
Jason Phang's avatar
gpt3  
Jason Phang committed
103

Jason Phang's avatar
Jason Phang committed
104
    def fewshot_description(self):
Jason Phang's avatar
checkin  
Jason Phang committed
105
106
        return ""

Jason Phang's avatar
Jason Phang committed
107
    def fewshot_context(self, doc, num_fewshot, provide_description):
Jason Phang's avatar
Jason Phang committed
108
109
        raw_description = self.fewshot_description()
        description = (raw_description + "\n\n") if provide_description and raw_description else ""
Jason Phang's avatar
Jason Phang committed
110
111
112
113
114
        labeled_examples = "\n\n".join(
            map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
        ) + "\n\n"
        example = self.doc_to_text(doc, include_target=False).strip()
        return description + labeled_examples + example
Jason Phang's avatar
checkin  
Jason Phang committed
115

Jason Phang's avatar
gpt3  
Jason Phang committed
116
117
118
119
120
121
122
123
124
125
126
127
128

class Registry:
    def __init__(self, registry_name):
        self.registry_name = registry_name
        self.registry = {}

    def register(self, name):
        def register_cls(new_cls):
            if name in self.registry:
                raise ValueError('Cannot register duplicate ({})'.format(self.registry_name, name))
            self.registry[name] = new_cls
            return new_cls
        return register_cls