TyDiQA.py 3.22 KB
Newer Older
KhalidAlt's avatar
KhalidAlt committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference
https://arxiv.org/abs/1902.01007

A controlled evaluation set called HANS (Heuristic Analysis for NLI Systems),
which contains many examples where the heuristics fail.

Homepage: https://github.com/tommccoy1/hans
"""
from lm_eval.base import PromptSourceTask


_CITATION = """\
@article{tydiqa,
title   = {TyDi QA: A Benchmark for Information-Seeking Question Answering in Typologically Diverse Languages},
author  = {Jonathan H. Clark and Eunsol Choi and Michael Collins and Dan Garrette and Tom Kwiatkowski and Vitaly Nikolaev and Jennimaria Palomaki}
year    = {2020},
journal = {Transactions of the Association for Computational Linguistics}
}
"""


class Primary(PromptSourceTask):
    """Task for the `primary_task` configuration of the `tydiqa` dataset.

    Exposes the train and validation splits; no test split is advertised.
    """

    # Task version reported alongside results.
    VERSION = 0
    # Dataset identifier and configuration name used to load the data.
    DATASET_PATH = "tydiqa"
    DATASET_NAME = "primary_task"

    def has_training_docs(self):
        """Return True: a training split is available."""
        return True

    def has_validation_docs(self):
        """Return True: a validation split is available."""
        return True

    def has_test_docs(self):
        """Return False: no test split is exposed by this task."""
        return False

    def training_docs(self):
        """Return the training documents as a cached list.

        Returns None when `has_training_docs()` is false.
        """
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        """Return the validation split, or None when it is unavailable."""
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        """Return the test split, or None when it is unavailable."""
        if self.has_test_docs():
            return self.dataset["test"]

    def process_results(self, doc, results):
        """Score one document by exact match between prediction and target.

        Args:
            doc: The dataset example being evaluated.
            results: Model outputs for `doc`; only the first entry is used
                and it is stripped of surrounding whitespace.

        Returns:
            A dict `{"acc": bool}`. When `self.save_examples` is truthy, a
            tuple `(metrics, example)` is returned instead, where `example`
            records the prediction and the target.
        """
        pred = results[0].strip()

        # NOTE(review): indexing with "sub_label" assumes `doc_to_target`
        # returns a mapping that carries this key -- confirm against the
        # PromptSourceTask implementation and the prompt templates in use.
        target = self.doc_to_target(doc)['sub_label']

        out = {"acc": pred == target}

        if self.save_examples:
            example = {
                "pred": pred,
                "target": target,
            }
            return out, example

        return out


class Secondary(PromptSourceTask):
    """Task for the `secondary_task` configuration of the `tydiqa` dataset.

    Exposes the train and validation splits; no test split is advertised.
    """

    # Task version reported alongside results.
    VERSION = 0
    # Dataset identifier and configuration name used to load the data.
    DATASET_PATH = "tydiqa"
    DATASET_NAME = "secondary_task"

    def has_training_docs(self):
        """Report that a training split exists."""
        return True

    def has_validation_docs(self):
        """Report that a validation split exists."""
        return True

    def has_test_docs(self):
        """Report that no test split is exposed."""
        return False

    def training_docs(self):
        """Return the training split as a list, materialized on first use."""
        if not self.has_training_docs():
            return None
        # The list is memoized on `self._training_docs` so repeated few-shot
        # sampling does not re-read the split. Switch to a generator if the
        # split ever becomes too large to hold in memory.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the validation split, or None if the task has none."""
        if not self.has_validation_docs():
            return None
        return self.dataset["validation"]

    def test_docs(self):
        """Return the test split, or None if the task has none."""
        if not self.has_test_docs():
            return None
        return self.dataset["test"]