import numpy as np
&'s avatar
& committed
2
3
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask

class ANLIBase(HFTask):
    """Shared ANLI task logic; each concrete subclass selects a round via ``SPLIT``.

    Documents are read from the ``anli`` dataset, whose splits are named per
    round (``train_r<N>``, ``dev_r<N>``, ``test_r<N>``). Each example is scored
    as a three-way choice between the continuations " True", " Neither" and
    " False".
    """

    VERSION = 0
    DATASET_PATH = "anli"
    DATASET_NAME = None
    # Round number used to build the split names; None here, set by each
    # concrete subclass.
    SPLIT = None

    # Answer words indexed by the integer ``label`` field of a document:
    #   0 -> "True"    (entailment)
    #   1 -> "Neither" (neutral)
    #   2 -> "False"   (contradiction)
    # Both doc_to_target and construct_requests read this tuple, so the gold
    # target and the scored candidates cannot fall out of sync.
    _CHOICES = ("True", "Neither", "False")

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def _split_name(self, prefix):
        """Return this round's split name for ``prefix``, e.g. ``"dev_r2"``."""
        return prefix + "_r" + str(self.SPLIT)

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                # Materialise the training split as a list once and reuse it.
                self._training_docs = list(self.data[self._split_name("train")])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.data[self._split_name("dev")]

    def test_docs(self):
        if self.has_test_docs():
            return self.data[self._split_name("test")]

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def doc_to_text(self, doc):
        # OA does this a bit weirdly: they prepend "anli 1:  anli 1:  " to the
        # beginning of the prompt (yes, repeating it!). Also, " True, False, or
        # Neither?" is appended directly onto the question, with no "Answer:"
        # or even a newline. Do we *really* want to do it exactly as OA did?
        return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'

    def doc_to_target(self, doc):
        # The leading space matches the continuations scored in
        # construct_requests.
        return " " + self._CHOICES[doc['label']]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns: tuple
            One log-likelihood request per entry of ``_CHOICES``, in label order,
            so the index of the best-scoring request is the predicted label.
        """
        lls = []
        for choice in self._CHOICES:
            # Only the log-likelihood is used; the second value is discarded.
            ll, _ = rf.loglikelihood(ctx, " " + choice)
            lls.append(ll)
        return tuple(lls)

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests, in
            ``_CHOICES`` (i.e. label) order.
        :returns: dict
            ``{"acc": 1.0}`` if the highest-scoring choice is the gold label,
            otherwise ``{"acc": 0.0}``.
        """
        gold = doc["label"]
        pred = np.argmax(results)
        # Plain float instead of numpy.bool_: same mean, and JSON-serialisable.
        return {
            "acc": 1.0 if pred == gold else 0.0
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "acc": mean
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True
        }

class ANLIRound1(ANLIBase):
    """ANLI round 1: uses the ``train_r1`` / ``dev_r1`` / ``test_r1`` splits."""

    SPLIT = 1

class ANLIRound2(ANLIBase):
    """ANLI round 2: uses the ``train_r2`` / ``dev_r2`` / ``test_r2`` splits."""

    SPLIT = 2

class ANLIRound3(ANLIBase):
    """ANLI round 3: uses the ``train_r3`` / ``dev_r3`` / ``test_r3`` splits."""

    SPLIT = 3