race.py 4.45 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
import collections
2
import datasets
Jon Tow's avatar
Jon Tow committed
3
import numpy as np
&'s avatar
& committed
4
5
from lm_eval.base import rf
from ..metrics import mean
Jon Tow's avatar
Jon Tow committed
6
from . common import HFTask
Leo Gao's avatar
Leo Gao committed
7
8
9
10
11
12
13
14


class each:
    """Pipeline helper: ``iterable >> each(func)`` maps ``func`` over ``iterable``.

    The mapping is triggered through the reflected right-shift hook, so it
    runs when the left operand does not handle ``>>`` for this pairing itself.
    """

    def __init__(self, f):
        # Callable applied to every element of the piped iterable.
        self.f = f

    def __rrshift__(self, other):
        """Return a list holding ``self.f(item)`` for each item of ``other``, in order."""
        transform = self.f
        return [*map(transform, other)]
Leo Gao's avatar
Leo Gao committed
15
16


17
class RACE(HFTask):
    """RACE task in which every document is one passage with all its questions.

    Rows loaded through ``datasets.load_dataset`` are merged by article text.
    Each resulting document is a dict with an ``'article'`` string and a
    ``'problems'`` list of ``{'question', 'answer', 'options'}`` dicts. Only
    the last problem of a document is scored; the earlier problems are
    rendered together with their answers as part of the prompt.
    """

    VERSION = 0
    DATASET_PATH = "race"
    DATASET_NAME = "high"

    # Memo of collated splits. It lives on the class, so it is shared by all
    # instances (and by subclasses that do not override it); entries are keyed
    # by (DATASET_PATH, DATASET_NAME, split) so that different dataset
    # configurations never read each other's documents.
    cache = {}
    # Maps the answer letter stored in a problem to an index into its options.
    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

    # Evaluated once, when the class body runs.
    assert datasets.__version__ == "1.15.1", "RACE requires datasets==1.15.1!"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def _collate_data(self, set):
        """Load one split and merge its rows into one document per article.

        :param set: str
            Name of the split to index the loaded dataset with
            (e.g. "train", "validation", "test").
        :returns: list of dicts, each with the keys 'article' and 'problems'.
            The result is memoised in ``cache`` and the same list object is
            returned on later calls.
        """
        cache_key = (self.DATASET_PATH, self.DATASET_NAME, set)
        if cache_key in self.cache:
            return self.cache[cache_key]

        # One big issue with HF's implementation of this dataset: it makes a
        # separate document for each question; meanwhile, in the GPT3 paper it
        # is shown that one document is made per passage.
        rows_by_article = collections.defaultdict(list)
        dataset = datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
        for item in dataset[set]:
            rows_by_article[item['article']].append(item)

        # Dict insertion order keeps articles in order of first appearance and
        # problems in their original row order.
        docs = [
            {
                'article': article,
                'problems': [
                    {
                        'question': row['question'],
                        'answer': row['answer'],
                        'options': row['options'],
                    }
                    for row in rows
                ],
            }
            for article, rows in rows_by_article.items()
        ]

        self.cache[cache_key] = docs
        return docs

    def training_docs(self):
        return self._collate_data("train")

    def validation_docs(self):
        return self._collate_data("validation")

    def test_docs(self):
        return self._collate_data("test")

    @classmethod
    def get_answer_option(cls, problem):
        """Return the text of the option selected by the problem's answer letter."""
        answer = cls.letter_to_num[problem['answer']]
        return problem['options'][answer]

    @classmethod
    def last_problem(cls, doc):
        """Return the final problem of ``doc`` (the one that gets scored)."""
        return doc['problems'][-1]

    def doc_to_text(self, doc):
        """Build the prompt text for ``doc``.

        The article comes first, then every problem except the last with its
        answer filled in, and finally the bare question of the last problem.

        :param doc:
            A document as produced by ``_collate_data``.
        :returns: str
        """
        # NOTE(review): the cloze branch below used to emit only the trailing
        # placeholder (``question[-5:]``) instead of the question stem, so
        # this fix changes the prompt text. If VERSION is meant to track
        # prompt changes it probably needs a bump -- confirm with maintainers.
        parts = ['Article: ' + doc['article'] + '\n\n']
        for problem in doc['problems'][:-1]:
            question = problem['question']
            answer = self.get_answer_option(problem)
            if question.endswith('  _  .'):
                # Cloze-style question: strip the trailing " _  ." placeholder
                # and put the answer text in its place.
                parts.append(question[:-5] + answer + '\n')
            else:
                parts.append('Question: ' + question + '\n')
                parts.append('Answer: ' + answer + '\n')
        parts.append(self.last_problem(doc)['question'])
        return ''.join(parts)

    def doc_to_target(self, doc):
        """Return the answer text of the last problem, prefixed with a space."""
        return " " + self.get_answer_option(self.last_problem(doc))

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        One loglikelihood request is built per answer option of the last
        problem, in option order; only element ``[0]`` of each request is kept.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        problem = self.last_problem(doc)
        return [
            rf.loglikelihood(ctx, " " + option)[0]
            for option in problem['options']
        ]

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        gold = self.letter_to_num[self.last_problem(doc)['answer']]
        # The highest-scoring option is taken as the prediction.
        pred = np.argmax(results)
        return {
            "acc": int(pred == gold)
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "acc": mean
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True
        }