import numpy as np
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask

class ANLIBase(HFTask):
    """Shared implementation of the Adversarial NLI (ANLI) task.

    The ``anli`` dataset has one ``train_rN`` / ``dev_rN`` / ``test_rN`` split
    per round; concrete subclasses select a round by setting ``SPLIT``.

    Each example is scored by requesting the log-likelihood of every answer
    word in ``CHOICES`` after the prompt; the index of the most likely word is
    the prediction and is compared against ``doc["label"]``.
    """

    DATASET_PATH = "anli"
    DATASET_NAME = None
    # Round number used to build the split names ("train_r1", "dev_r2", ...);
    # set by the concrete subclasses.
    SPLIT = None
    # Answer words, indexed by ``doc["label"]``:
    #   0 -> "True"    (entailment)
    #   1 -> "Neither" (neutral)
    #   2 -> "False"   (contradiction)
    # doc_to_target and construct_requests both read this tuple so the gold
    # target and the scored continuations cannot drift apart.
    CHOICES = ("True", "Neither", "False")

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        """Return this round's training split as a list.

        The list is built once and cached on ``self._training_docs``.
        """
        if self.has_training_docs():
            if self._training_docs is None:
                self._training_docs = list(self.data["train_r" + str(self.SPLIT)])
            return self._training_docs

    def validation_docs(self):
        """Return this round's dev split as stored in ``self.data`` (not copied)."""
        if self.has_validation_docs():
            return self.data["dev_r" + str(self.SPLIT)]

    def test_docs(self):
        """Return this round's test split as stored in ``self.data`` (not copied)."""
        if self.has_test_docs():
            return self.data["test_r" + str(self.SPLIT)]

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def doc_to_text(self, doc):
        """Format a premise/hypothesis pair as a prompt ending in "Answer:"."""
        # OA does this a bit weirdly: they prepend "anli 1:  anli 1:  " to the beginning
        # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
        # appended onto the question, with no "Answer:" or even a newline. Do we *really*
        # want to do it exactly as OA did?
        return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'

    def doc_to_target(self, doc):
        """Return the gold answer word, with a leading space to follow "Answer:"."""
        return " " + self.CHOICES[doc['label']]

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns: tuple
            One loglikelihood request per entry of ``CHOICES``, in label order,
            so position i of ``results`` in process_results corresponds to label i.
        """
        lls = []
        for choice in self.CHOICES:
            # Only the log-likelihood component is kept; the second element
            # of the request is discarded, as in the original implementation.
            ll, _ = rf.loglikelihood(ctx, " " + choice)
            lls.append(ll)
        return tuple(lls)

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests, one
            log-likelihood per entry of ``CHOICES`` in label order.
        :returns: dict
            ``{"acc": bool}``: whether the most likely answer word is the gold one.
        """
        gold = doc["label"]
        # int() turns numpy's integer index into a plain Python int, so "acc"
        # is a plain bool rather than numpy.bool_.
        pred = int(np.argmax(results))
        return {
            "acc": pred == gold
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "acc": mean
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True
        }

class ANLIRound1(ANLIBase):
    """ANLI round 1: reads the ``train_r1`` / ``dev_r1`` / ``test_r1`` splits."""

    SPLIT = 1

class ANLIRound2(ANLIBase):
    """ANLI round 2: reads the ``train_r2`` / ``dev_r2`` / ``test_r2`` splits."""

    SPLIT = 2

class ANLIRound3(ANLIBase):
    """ANLI round 3: reads the ``train_r3`` / ``dev_r3`` / ``test_r3`` splits."""

    SPLIT = 3