xquad.py 5.75 KB
Newer Older
sdtblck's avatar
sdtblck committed
1
from .squad import SQuAD2
sdtblck's avatar
sdtblck committed
2
3
4
from math import exp
from functools import partial
import datasets
sdtblck's avatar
sdtblck committed
5
from lm_eval.base import rf
sdtblck's avatar
sdtblck committed
6
7
8
9
10
11
12
13
14


def _squad_metric(predictions, references):
    """Score predictions against references with the `datasets` "squad" metric.

    :param predictions:
        Prediction entries, forwarded unchanged to the metric's `compute`.
    :param references:
        Reference entries, forwarded unchanged to the metric's `compute`.
    :returns:
        Whatever `compute` returns; `_squad_agg` indexes it by metric name.
    """
    scorer = datasets.load_metric("squad")
    scores = scorer.compute(predictions=predictions, references=references)
    return scores


def _squad_agg(key, items):
    """Aggregate per-document (prediction, reference) pairs into one score.

    :param key: str
        Name of the entry to read from the computed metric result
        (this file uses 'exact_match' and 'f1').
    :param items:
        Iterable of (prediction, reference) pairs, as produced per document
        by `process_results`.
    :returns:
        The `key` entry of the result returned by `_squad_metric`.
    """
    predictions, references = zip(*items)

    # 'prediction_text' may arrive as a list; the first element is used as
    # the answer string. Work on copies so the caller's dicts (shared between
    # the 'exact_match' and 'f1' aggregations) are never modified in place.
    normalized = []
    for prediction in predictions:
        text = prediction['prediction_text']
        if isinstance(text, list):
            prediction = {**prediction, 'prediction_text': text[0]}
        normalized.append(prediction)

    return _squad_metric(predictions=normalized, references=references)[key]


class XQuADBase(SQuAD2):
    """Shared prompt construction and scoring for the XQuAD tasks.

    Language-specific subclasses set DATASET_NAME and may override the
    BACKGROUND / QUESTION / ANSWER prompt labels.
    """
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = None

    # Prompt labels used by doc_to_text; these defaults are English.
    BACKGROUND = "Background:"
    QUESTION = "Question:"
    ANSWER = "Answer:"

    def has_training_docs(self):
        """Report that this task provides no training documents."""
        return False

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns:
            A single `rf.greedy_until` request for `ctx` with ['\\n'] as its
            second argument.
        """
        return rf.greedy_until(ctx, ['\n'])

    def doc_to_text(self, doc):
        """Build the prompt text for one document.

        :param doc:
            A document providing the 'context' and 'question' strings.
        :returns: str
            BACKGROUND label, the context, the QUESTION label followed
            immediately by the question, then the ANSWER label, with blank
            lines between the sections.
        """
        # NOTE(review): QUESTION is joined directly to the question text (no
        # space or newline), unlike BACKGROUND. Left as-is because changing the
        # prompt would change task results -- confirm whether this is intended.
        return (
            self.BACKGROUND + '\n\n' + doc['context'] + '\n\n'
            + self.QUESTION + doc['question'] + '\n\n'
            + self.ANSWER
        )

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :returns:
            A dict mapping 'exact_match' and 'f1' to the same
            (prediction, reference) pair; the actual scores are computed
            later by the functions returned from `aggregation`.
        """
        # Presumably the harness hands the single generation back wrapped in a
        # list (the aggregation step already tolerates that); unwrap it here so
        # 'prediction_text' is the answer itself.
        continuation = results[0] if isinstance(results, list) else results

        predictions = {
            'id': doc['id'],
            'prediction_text': continuation,
        }

        references = {
            'id': doc['id'],
            'answers': doc['answers'],
        }

        pair = (predictions, references)
        return {
            'exact_match': pair,  # Exact match (the normalized answer exactly match the gold answer)
            'f1': pair,  # The F-score of predicted tokens versus the gold answer
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            'exact_match': partial(_squad_agg, 'exact_match'),  # Exact match (the normalized answer exactly match the gold answer)
            'f1': partial(_squad_agg, 'f1'),  # The F-score of predicted tokens versus the gold answer
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            'exact_match': True,  # Exact match (the normalized answer exactly match the gold answer)
            'f1': True,  # The F-score of predicted tokens versus the gold answer
        }


sdtblck's avatar
sdtblck committed
101
class XQuADAr(XQuADBase):  # arabic
    """XQuAD task for the 'xquad.ar' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.ar'

    # NOTE(review): each label stores the colon as its first character in
    # logical order -- confirm this is the intended placement for
    # right-to-left text.
    BACKGROUND = ":معرفتي"
    QUESTION = ":سؤال"
    ANSWER = ":إجابه"
sdtblck's avatar
sdtblck committed
108
109


sdtblck's avatar
sdtblck committed
110
class XQuADDe(XQuADBase):  # german
    """XQuAD task for the 'xquad.de' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.de'

    # German prompt labels.
    BACKGROUND = "Hintergrund:"
    QUESTION = "Frage:"
    ANSWER = "Antwort:"
sdtblck's avatar
sdtblck committed
117
118


sdtblck's avatar
sdtblck committed
119
class XQuADZh(XQuADBase):  # chinese
    """XQuAD task for the 'xquad.zh' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.zh'

    # Chinese prompt labels; note they end in a full-width colon.
    # NOTE(review): QUESTION uses the Traditional form -- verify it matches
    # the script used by the dataset.
    BACKGROUND = "背景:"
    QUESTION = "問題:"
    ANSWER = "答案:"
sdtblck's avatar
sdtblck committed
126
127


sdtblck's avatar
sdtblck committed
128
class XQuADVi(XQuADBase):  # vietnamese
    """XQuAD task for the 'xquad.vi' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.vi'

    # Vietnamese prompt labels.
    # NOTE(review): "lý lịch" usually refers to a personal record/résumé --
    # confirm the wording with a native speaker.
    BACKGROUND = "lý lịch:"
    QUESTION = "câu hỏi:"
    ANSWER = "câu trả lời:"
sdtblck's avatar
sdtblck committed
135
136


sdtblck's avatar
sdtblck committed
137
class XQuADEn(XQuADBase):  # english
    """XQuAD task for the 'xquad.en' dataset configuration.

    Defines no prompt labels of its own, so the ones from XQuADBase apply.
    """
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.en'


sdtblck's avatar
sdtblck committed
143
class XQuADEs(XQuADBase):  # spanish
    """XQuAD task for the 'xquad.es' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.es'

    # Spanish prompt labels.
    BACKGROUND = "antecedentes:"
    QUESTION = "pregunta:"
    ANSWER = "respuesta:"
sdtblck's avatar
sdtblck committed
150
151


sdtblck's avatar
sdtblck committed
152
class XQuADHi(XQuADBase):  # hindi
    """XQuAD task for the 'xquad.hi' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.hi'

    # Hindi prompt labels.
    BACKGROUND = "पृष्ठभूमि:"
    QUESTION = "सवाल:"
    ANSWER = "उत्तर:"
sdtblck's avatar
sdtblck committed
159
160


sdtblck's avatar
sdtblck committed
161
class XQuADEl(XQuADBase):  # greek
    """XQuAD task for the 'xquad.el' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.el'

    # Greek prompt labels.
    BACKGROUND = "Ιστορικό:"
    QUESTION = "Ερώτηση:"
    ANSWER = "Απάντηση:"
sdtblck's avatar
sdtblck committed
168
169


sdtblck's avatar
sdtblck committed
170
class XQuADTh(XQuADBase):  # thai
    """XQuAD task for the 'xquad.th' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.th'

    # Thai prompt labels.
    BACKGROUND = "พื้นหลัง:"
    QUESTION = "คำถาม:"
    ANSWER = "ตอบ:"
sdtblck's avatar
sdtblck committed
177
178


sdtblck's avatar
sdtblck committed
179
class XQuADTr(XQuADBase):  # turkish
    """XQuAD task for the 'xquad.tr' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.tr'

    # Turkish prompt labels.
    BACKGROUND = "arka fon:"
    QUESTION = "soru:"
    ANSWER = "Cevap:"
    
sdtblck's avatar
sdtblck committed
187

sdtblck's avatar
sdtblck committed
188
class XQuADRu(XQuADBase):  # russian
    """XQuAD task for the 'xquad.ru' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.ru'

    # Russian prompt labels.
    # NOTE(review): "отвечать" is the verb "to answer" rather than the noun
    # "answer" -- confirm the wording with a native speaker.
    BACKGROUND = "задний план:"
    QUESTION = "вопрос:"
    ANSWER = "отвечать:"
sdtblck's avatar
sdtblck committed
195
196


sdtblck's avatar
sdtblck committed
197
class XQuADRo(XQuADBase):  # romanian
    """XQuAD task for the 'xquad.ro' dataset configuration."""
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.ro'

    # Romanian prompt labels.
    BACKGROUND = "fundal:"
    QUESTION = "întrebare:"
    ANSWER = "Răspuns:"