base.py 785 Bytes
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import abc
import random

class LM(abc.ABC):
    """Abstract interface for a language model.

    Concrete subclasses must override both ``generate`` and ``nll_of``;
    the class cannot be instantiated until they do.
    """

    @abc.abstractmethod
    def generate(self, context, until):
        """Abstract hook; the base implementation does nothing and returns None.

        NOTE(review): judging by the names, this presumably produces text
        following ``context`` and stops at ``until`` -- confirm against a
        concrete implementation.
        """

    @abc.abstractmethod
    def nll_of(self, context, continuation):
        """Abstract hook; the base implementation does nothing and returns None.

        NOTE(review): judging by the name, this presumably scores
        ``continuation`` given ``context`` as a negative log-likelihood --
        confirm against a concrete implementation.
        """


class Dataset(abc.ABC):
    @abc.abstractmethod
    def training_docs(self):
        pass
    
    @abc.abstractmethod
    def validation_docs(self):
        pass
    
    @abc.abstractmethod
    def test_docs(self):
        pass
    
    def fewshot_examples(self, k):
        traindocs = list(self.training_docs())
        random.seed(123)
        random.shuffle(traindocs)

        return traindocs[:k]
    
    @abc.abstractmethod
    def fewshot_description(self):
        pass

    @abc.abstractmethod
    def doc_to_text(self, doc, include_target=True):
        pass