xquad.py 5.08 KB
Newer Older
sdtblck's avatar
sdtblck committed
1
from .squad import SQuAD2
sdtblck's avatar
sdtblck committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from math import exp
from functools import partial
import datasets


def _squad_metric(predictions, references):
    """Score predictions against references with the ``squad`` metric.

    The metric object is obtained through ``datasets.load_metric("squad")``
    on every call, and the value of its ``compute`` call is returned unchanged.

    :param predictions:
        Prediction entries forwarded to the metric as ``predictions``.
    :param references:
        Reference entries forwarded to the metric as ``references``.
    :returns:
        Whatever ``compute`` returns; callers in this module index it by
        submetric name.
    """
    metric = datasets.load_metric("squad")
    scores = metric.compute(predictions=predictions, references=references)
    return scores


def _squad_agg(key, items):
    """Aggregate per-document ``(prediction, reference)`` pairs into one score.

    :param key:
        Name of the submetric to pick out of the computed scores.
    :param items:
        Iterable of ``(prediction, reference)`` pairs, one per document.
    :returns:
        The entry stored under ``key`` in the result of ``_squad_metric``.
    """
    # Transpose the list of pairs into one sequence of predictions and one
    # sequence of references.
    preds, refs = zip(*items)
    scores = _squad_metric(predictions=preds, references=refs)
    return scores[key]


class XQuADBase(SQuAD2):
    """Shared implementation for the per-language XQuAD tasks.

    Builds a background/question/answer prompt from each document and
    reports two submetrics, ``exact`` and ``f1``, computed by ``_squad_agg``.
    The class attributes ``BACKGROUND``, ``QUESTION`` and ``ANSWER`` hold the
    prompt labels; the language subclasses in this module override them.
    """

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = None

    # Default (English) prompt labels.
    BACKGROUND = "Background:"
    QUESTION = "Question:"
    ANSWER = "Answer:"

    def has_training_docs(self):
        """Report that this task exposes no training documents."""
        return False

    def doc_to_text(self, doc):
        """Render the prompt for a single document.

        :param doc:
            A document providing the ``context`` and ``question`` fields.
        :returns:
            The prompt string: background label, context, question label
            immediately followed by the question, and the answer label.
        """
        # NOTE(review): no separator is inserted between the QUESTION label and
        # the question text (e.g. "Question:Who ..."). Kept as-is so the prompt
        # stays byte-identical -- confirm whether this is intended.
        pieces = (
            self.BACKGROUND, '\n\n', doc['context'], '\n\n',
            self.QUESTION, doc['question'], '\n\n',
            self.ANSWER,
        )
        return "".join(pieces)

    def process_results(self, doc, results):
        """Package one document's model outputs for later metric aggregation.

        :param doc:
            The document as returned from training_docs, validation_docs, or
            test_docs. Its ``id`` and ``answers`` fields are read here.
        :param results:
            The results of the requests created in construct_requests.
            Unpacked as ``(continuation, (logprob_unanswerable, _))``.
        :returns:
            A dict mapping each submetric name (``exact``, ``f1``) to the
            same ``(prediction, reference)`` pair. The scores themselves are
            computed by the functions returned from ``aggregation``.
        """
        continuation, (logprob_unanswerable, _) = results

        prediction = {
            'id': doc['id'],
            'prediction_text': continuation,
            # exp() converts the log-probability into a probability.
            'no_answer_probability': exp(logprob_unanswerable),
        }
        reference = {
            'id': doc['id'],
            'answers': doc['answers'],
        }
        pair = (prediction, reference)

        return {
            'exact': pair,  # Exact match (the normalized answer exactly match the gold answer)
            'f1': pair,  # The F-score of predicted tokens versus the gold answer
        }

    def aggregation(self):
        """
        :returns: {str: [(prediction, reference)] -> score}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate the per-document pairs produced by
            ``process_results`` into a single value.
        """
        return {
            name: partial(_squad_agg, name)
            for name in ('exact', 'f1')
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            'exact': True,  # Exact match (the normalized answer exactly match the gold answer)
            'f1': True,  # The F-score of predicted tokens versus the gold answer
        }


sdtblck's avatar
sdtblck committed
86
class XQuADAr(XQuADBase):  # arabic
    """XQuAD task for the Arabic configuration (``xquad.ar``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.ar'

    # Arabic prompt labels, replacing the English defaults on XQuADBase.
    # NOTE(review): each literal starts with the colon, so the colon precedes
    # the label word in the prompt string -- confirm this is intended for
    # right-to-left text.
    BACKGROUND = ":معرفتي"
    QUESTION = ":سؤال"
    ANSWER = ":إجابه"
sdtblck's avatar
sdtblck committed
93
94


sdtblck's avatar
sdtblck committed
95
class XQuADDe(XQuADBase):  # german
    """XQuAD task for the German configuration (``xquad.de``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.de'

    # German prompt labels, replacing the English defaults on XQuADBase.
    # NOTE(review): "Antworten" is the plural/verb form; the singular noun
    # would be "Antwort" -- left unchanged to keep the prompt stable, confirm.
    BACKGROUND = "Hintergrund:"
    QUESTION = "Frage:"
    ANSWER = "Antworten:"
sdtblck's avatar
sdtblck committed
102
103


sdtblck's avatar
sdtblck committed
104
class XQuADZh(XQuADBase):  # chinese
    """XQuAD task for the Chinese configuration (``xquad.zh``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.zh'

    # Chinese prompt labels, replacing the English defaults on XQuADBase.
    # NOTE(review): the QUESTION label uses the Traditional-script character
    # for "question"; presumably the dataset text is Simplified -- confirm
    # before changing, since this alters the prompt.
    BACKGROUND = "背景:"
    QUESTION = "問題:"
    ANSWER = "回答:"
sdtblck's avatar
sdtblck committed
111
112


sdtblck's avatar
sdtblck committed
113
class XQuADVi(XQuADBase):  # vietnamese
    """XQuAD task for the Vietnamese configuration (``xquad.vi``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.vi'

    # Vietnamese prompt labels, replacing the English defaults on XQuADBase.
    BACKGROUND = "lý lịch:"
    QUESTION = "câu hỏi:"
    ANSWER = "câu trả lời:"
sdtblck's avatar
sdtblck committed
120
121


sdtblck's avatar
sdtblck committed
122
class XQuADEn(XQuADBase):  # english
    """XQuAD task for the English configuration (``xquad.en``).

    Defines no prompt labels of its own, so the English defaults declared on
    ``XQuADBase`` are used.
    """

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.en'


sdtblck's avatar
sdtblck committed
128
class XQuADEs(XQuADBase):  # spanish
    """XQuAD task for the Spanish configuration (``xquad.es``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.es'

    # Spanish prompt labels, replacing the English defaults on XQuADBase.
    BACKGROUND = "antecedentes:"
    QUESTION = "pregunta:"
    ANSWER = "respuesta:"
sdtblck's avatar
sdtblck committed
135
136


sdtblck's avatar
sdtblck committed
137
class XQuADHi(XQuADBase):  # hindi
    """XQuAD task for the Hindi configuration (``xquad.hi``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.hi'

    # Hindi prompt labels, replacing the English defaults on XQuADBase.
    BACKGROUND = "पृष्ठभूमि:"
    QUESTION = "सवाल:"
    ANSWER = "उत्तर:"
sdtblck's avatar
sdtblck committed
144
145


sdtblck's avatar
sdtblck committed
146
class XQuADEl(XQuADBase):  # greek
    """XQuAD task for the Greek configuration (``xquad.el``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.el'

    # Greek prompt labels, replacing the English defaults on XQuADBase.
    BACKGROUND = "Ιστορικό:"
    QUESTION = "ερώτηση:"
    ANSWER = "απάντηση:"
sdtblck's avatar
sdtblck committed
153
154


sdtblck's avatar
sdtblck committed
155
class XQuADTh(XQuADBase):  # thai
    """XQuAD task for the Thai configuration (``xquad.th``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.th'

    # Thai prompt labels, replacing the English defaults on XQuADBase.
    BACKGROUND = "พื้นหลัง:"
    QUESTION = "คำถาม:"
    ANSWER = "ตอบ:"
sdtblck's avatar
sdtblck committed
162
163


sdtblck's avatar
sdtblck committed
164
class XQuADTr(XQuADBase):  # turkish
    """XQuAD task for the Turkish configuration (``xquad.tr``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.tr'

    # Turkish prompt labels, replacing the English defaults on XQuADBase.
    BACKGROUND = "arka fon:"
    QUESTION = "soru:"
    ANSWER = "Cevap:"
    
sdtblck's avatar
sdtblck committed
172

sdtblck's avatar
sdtblck committed
173
class XQuADRu(XQuADBase):  # russian
    """XQuAD task for the Russian configuration (``xquad.ru``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.ru'

    # Russian prompt labels, replacing the English defaults on XQuADBase.
    # NOTE(review): the ANSWER label is the infinitive verb ("to answer")
    # rather than the noun -- left unchanged to keep the prompt stable,
    # confirm with a native speaker.
    BACKGROUND = "задний план:"
    QUESTION = "вопрос:"
    ANSWER = "отвечать:"
sdtblck's avatar
sdtblck committed
180
181


sdtblck's avatar
sdtblck committed
182
class XQuADRo(XQuADBase):  # romanian
    """XQuAD task for the Romanian configuration (``xquad.ro``)."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = 'xquad.ro'

    # Romanian prompt labels, replacing the English defaults on XQuADBase.
    BACKGROUND = "fundal:"
    QUESTION = "întrebare:"
    ANSWER = "Răspuns:"