race.py 4.53 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
import collections
2
import datasets
Jon Tow's avatar
Jon Tow committed
3
import numpy as np
&'s avatar
& committed
4
5
from lm_eval.base import rf
from ..metrics import mean
Jon Tow's avatar
Jon Tow committed
6
from . common import HFTask
Leo Gao's avatar
Leo Gao committed
7
8
9
10
11
12
13
14
15
16
17
18
19

import os
from functools import reduce
import operator
from tqdm import tqdm
import json

class each:
    """Pipeline helper: ``iterable >> each(f)`` applies ``f`` to every element.

    The result is always a concrete list, not a lazy iterator.
    """

    def __init__(self, f):
        # The callable applied to each element by the ``>>`` operator.
        self.f = f

    def __rrshift__(self, other):
        """Implement ``other >> self`` by mapping ``self.f`` over ``other``."""
        func = self.f
        return [func(element) for element in other]
Leo Gao's avatar
Leo Gao committed
20
21


22
class RACE(HFTask):
Leo Gao's avatar
Leo Gao committed
23
24
    DATASET_PATH = "race"
    DATASET_NAME = "high"
Leo Gao's avatar
Leo Gao committed
25
26

    cache = {}
Jon Tow's avatar
Jon Tow committed
27
    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
Leo Gao's avatar
Leo Gao committed
28
29
30
31
32
33
34
35
36
37
38

    def has_training_docs(self):
        """Report whether this task has training documents (always True)."""
        return True

    def has_validation_docs(self):
        """Report whether this task has validation documents (always True)."""
        return True

    def has_test_docs(self):
        """Report whether this task has test documents (always True)."""
        return True

    def _collate_data(self, set):
Leo Gao's avatar
Leo Gao committed
39
40
        if set in self.cache:
            return self.cache[set]
Leo Gao's avatar
Leo Gao committed
41
42
43
44
45
        # One big issue with HF's implementation of this dataset: it makes a
        # separate document for each question; meanwhile, in the GPT3 paper it
        # is shown that one document is made per passage.

        r = collections.defaultdict(list)
46
        for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]:
Leo Gao's avatar
Leo Gao committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
            r[item['article']].append(item)
        
        res = list(r.values() >> each(lambda x: {
            'article': x[0]['article'],
            'problems': x >> each(lambda y: {
                'question': y['question'],
                'answer': y['answer'],
                'options': y['options'],
            })
        }))

        self.cache[set] = res
        return res

    def training_docs(self):
        return self._collate_data("train")

    def validation_docs(self):
        return self._collate_data("validation")

    def test_docs(self):
        return self._collate_data("test")

    def fewshot_description(self):
        """Return the task description for few-shot prompts (currently empty)."""
        # TODO: figure out description
        description = ""
        return description

Jon Tow's avatar
Jon Tow committed
74
75
76
77
78
79
80
81
82
    @classmethod
    def get_answer_option(cls, problem):
        answer = cls.letter_to_num[problem['answer']]
        return problem['options'][answer]

    @classmethod
    def last_problem(cls, doc):
        return doc['problems'][-1]

83
    def doc_to_text(self, doc):
Jon Tow's avatar
Jon Tow committed
84
85
        text = 'Article: ' + doc['article'] + '\n\n'
        for problem in doc['problems'][:-1]:
Leo Gao's avatar
Leo Gao committed
86
87
88
89
90
91
            if problem['question'][-6:] == '  _  .':
                text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
            else:
                question = 'Question: ' + problem['question'] + '\n'
                answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
                text += question + answer
Leo Gao's avatar
Leo Gao committed
92
        text += self.last_problem(doc)['question']
Jon Tow's avatar
Jon Tow committed
93
        return text
Leo Gao's avatar
Leo Gao committed
94

95
    def doc_to_target(self, doc):
Jon Tow's avatar
Jon Tow committed
96
        return " " + self.get_answer_option(self.last_problem(doc))
Leo Gao's avatar
Leo Gao committed
97

Leo Gao's avatar
Leo Gao committed
98
99
100
101
102
103
104
105
106
107
108
    def construct_requests(self, doc, ctx):
        """Build the requests that will be sent to the LM for one document.

        One `rf.loglikelihood` request is created for each of the first four
        options of the document's last problem, with the option text (prefixed
        by a space) as the continuation of `ctx`.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns: list
            Element [0] of each `rf.loglikelihood(...)` result, in option order.
        """
        final_problem = self.last_problem(doc)
        requests = []
        for option_index in range(4):
            continuation = " " + final_problem['options'][option_index]
            requests.append(rf.loglikelihood(ctx, continuation)[0])
        return requests

Leo Gao's avatar
Leo Gao committed
116
117
118
119
120
121
122
123
124
125
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a 
        dict where keys are the names of submetrics and values are the values of 
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
Jon Tow's avatar
Jon Tow committed
126
127
128
129
130
        gold = self.letter_to_num[self.last_problem(doc)['answer']]
        pred = np.argmax(results)
        return {
            "acc": int(pred == gold)
        }
Leo Gao's avatar
Leo Gao committed
131
132
133
134
135
136
137

    def aggregation(self):
        """Map each submetric name to the function that aggregates its values.

        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics. Here "acc" is
            aggregated with `mean`.
        """
        aggregators = {}
        aggregators["acc"] = mean
        return aggregators
Leo Gao's avatar
Leo Gao committed
141
142
143
144
145
146
147

    def higher_is_better(self):
        """Tell, for each submetric, whether larger values are better.

        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better. Here "acc" maps
            to True.
        """
        direction = {}
        direction["acc"] = True
        return direction