xquad.py 3.7 KB
Newer Older
sdtblck's avatar
sdtblck committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from squad import SQuAD2
from math import exp
from functools import partial
import datasets


def _squad_metric(predictions, references):
    """Score SQuAD-style predictions against their gold references.

    Uses the ``squad_v2`` metric rather than the v1 ``squad`` metric for two
    reasons. The prediction dicts built by ``XQuADBase.process_results`` carry a
    ``no_answer_probability`` field, which only the v2 metric's input schema
    accepts. The aggregation code also looks up the ``'exact'`` key, which only
    the v2 metric returns; v1 reports ``'exact_match'`` instead.

    :param predictions:
        Sequence of dicts with ``id``, ``prediction_text`` and
        ``no_answer_probability`` keys.
    :param references:
        Sequence of dicts with ``id`` and ``answers`` keys.
    :returns:
        The score dict produced by the ``squad_v2`` metric (contains ``'exact'``
        and ``'f1'`` among other entries).
    """
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)


def _squad_agg(key, items):
    """Reduce all collected (prediction, reference) pairs to one SQuAD score.

    :param key:
        Name of the score to pull from the metric output (``'exact'`` or
        ``'f1'`` in this module).
    :param items:
        The per-document ``(prediction, reference)`` tuples emitted by
        ``process_results``.
    :returns:
        ``_squad_metric(...)[key]`` computed over every pair at once.
    """
    # Transpose [(p0, r0), (p1, r1), ...] into (p0, p1, ...) and (r0, r1, ...).
    all_predictions, all_references = zip(*items)
    scores = _squad_metric(predictions=all_predictions, references=all_references)
    return scores[key]


class XQuADBase(SQuAD2):
    """Shared XQuAD task logic layered on top of the SQuAD2 task.

    Language-specific subclasses pick a dataset configuration through
    ``DATASET_NAME``. Scoring is done with the SQuAD-style ``exact`` and ``f1``
    submetrics, aggregated by ``_squad_agg``.
    """

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = None

    def has_training_docs(self):
        """Report that this task exposes no training documents."""
        return False

    def process_results(self, doc, results):
        """Package one document's LM outputs for later SQuAD-style scoring.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests, unpacked
            here as ``(continuation, (logprob_unanswerable, _))``.
        :returns:
            A dict mapping each submetric name (``'exact'``, ``'f1'``) to the same
            ``(prediction, reference)`` pair. The actual scoring happens in the
            functions returned by ``aggregation``.
        """
        generated_text, (unanswerable_logprob, _) = results

        # Turn the log-likelihood into a probability. This is presumably the
        # likelihood of an "unanswerable" continuation requested by
        # SQuAD2.construct_requests; TODO confirm.
        unanswerable_prob = exp(unanswerable_logprob)

        prediction = {
            'id': doc['id'],
            'prediction_text': generated_text,
            'no_answer_probability': unanswerable_prob,
        }
        reference = {
            'id': doc['id'],
            'answers': doc['answers'],
        }

        pair = (prediction, reference)
        # exact: the normalized prediction exactly matches a gold answer.
        # f1: token-level F-score of the prediction against the gold answer.
        return {'exact': pair, 'f1': pair}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            Mapping from each submetric name to the function that folds the
            per-document items into a single score.
        """
        return {name: partial(_squad_agg, name) for name in ('exact', 'f1')}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            Mapping from each submetric name to whether larger values are better
            (true for both exact match and F1).
        """
        return dict.fromkeys(('exact', 'f1'), True)


class XQuADAr(XQuADBase):
    """XQuAD task on the ``xquad.ar`` (Arabic) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.ar"

class XQuADDe(XQuADBase):
    """XQuAD task on the ``xquad.de`` (German) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.de"

class XQuADZh(XQuADBase):
    """XQuAD task on the ``xquad.zh`` (Chinese) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.zh"

class XQuADVi(XQuADBase):
    """XQuAD task on the ``xquad.vi`` (Vietnamese) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.vi"

class XQuADEn(XQuADBase):
    """XQuAD task on the ``xquad.en`` (English) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.en"

class XQuADEs(XQuADBase):
    """XQuAD task on the ``xquad.es`` (Spanish) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.es"

class XQuADHi(XQuADBase):
    """XQuAD task on the ``xquad.hi`` (Hindi) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.hi"

class XQuADEl(XQuADBase):
    """XQuAD task on the ``xquad.el`` (Greek) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.el"

class XQuADTh(XQuADBase):
    """XQuAD task on the ``xquad.th`` (Thai) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.th"

class XQuADTr(XQuADBase):
    """XQuAD task on the ``xquad.tr`` (Turkish) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.tr"

class XQuADRu(XQuADBase):
    """XQuAD task on the ``xquad.ru`` (Russian) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.ru"

class XQuADRo(XQuADBase):
    """XQuAD task on the ``xquad.ro`` (Romanian) configuration."""

    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = "xquad.ro"