"""
"""

import os
import numpy as np
import json
from ..utils import sh
from .common import HFTask, yesno
from lm_eval.base import MultipleChoiceTask, rf, mean
import zipfile


class Pubmed_QA(HFTask):
    DATASET_PATH = "pubmed_qa"
    DATASET_NAME = "pqa_labeled"
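
    # "pqa_labeled" selects PubMedQA's expert-annotated PQA-L subset on the
    # HuggingFace hub (question/abstract pairs with yes/no/maybe answers).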

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def fewshot_description(self):
        # Average ctx length in labelled dataset is 238.9
        # 2 few-shot examples push it beyond the context window
        return ""

    def doc_to_text(self, doc):
        ctxs = "\n".join(doc["context"]["contexts"])
        return "abstract: {}\nquestion: {}\nanswer:".format(
            ctxs,
jeffhsu3's avatar
jeffhsu3 committed
35
36
            doc["question"],
            doc["final_decision"]
jeffhsu3's avatar
jeffhsu3 committed
37
38
39
        )
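
    # Illustrative rendering of the prompt (abridged, hypothetical text):
    #   abstract: Programmed cell death (PCD) is the regulated death of cells...
    #   question: Do mitochondria play a role in remodelling lace plant leaves
    #     during programmed cell death?
    #   answer: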

    def doc_to_target(self, doc):
        return " {}".format(doc["final_decision"])

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns
        an iterable of Requests which will be sent to the LM.
        """
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")
        ll_maybe, _ = rf.loglikelihood(ctx, " maybe")
        return ll_yes, ll_no, ll_maybe

    def process_results(self, doc, results):
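        # Results arrive in the order emitted by construct_requests:
        # index 0 -> " yes", 1 -> " no", 2 -> " maybe".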
        gold = doc["final_decision"]
        ll_yes, ll_no, ll_maybe = results
        pred = np.argmax([ll_yes, ll_no, ll_maybe])
        return {
            "acc": ["yes", "no", "maybe"][pred] == gold,
        }

    def aggregation(self):
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        return {
            "acc": True,
        }
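

# A minimal usage sketch, assuming HFTask provides test_docs() backed by the
# HF dataset split (illustrative only; in practice the harness evaluator
# drives these calls):
#
#   task = Pubmed_QA()
#   doc = next(iter(task.test_docs()))
#   prompt = task.doc_to_text(doc)    # "abstract: ...\nquestion: ...\nanswer:"
#   target = task.doc_to_target(doc)  # " yes" / " no" / " maybe"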


class SciQ(MultipleChoiceTask):
    def download(self):
        if not os.path.exists('data/sciq'):
            os.makedirs('data/sciq')
            sh("wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip")
            with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
                zf.extractall("data/sciq/")
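
    # Note: SciQ.zip extracts to "data/sciq/SciQ dataset-2 3/"; the *_docs
    # methods below read the JSON splits from that directory.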

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    @staticmethod
    def _convert_standard(doc):
        choices = [
            doc["distractor1"],
            doc["distractor2"],
            doc["distractor3"],
            doc["correct_answer"],
        ]
        src = doc["support"]
        out_doc = {
            "source": src,
            "query": doc["question"],
            "choices": choices,
            "gold": 3,
        }
        return out_doc
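
    # The correct answer is always placed at index 3; since choices are scored
    # independently by loglikelihood rather than listed in the prompt, the
    # fixed gold position should not leak into the model input.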
    
    def load_docs(self, textfilename):
        with open(textfilename, 'r') as j:
            docs = json.loads(j.read())
        for record in docs:
            yield self._convert_standard(record)

    def fewshot_description(self):
        return ""

    def training_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")

    def validation_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")

    def test_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")

    def doc_to_text(self, doc):
        return " {}\n{}".format(doc["source"], doc["query"])


class EmrQA:
    # Stub: EmrQA support has not been implemented yet.
    def load_docs(self, textfilename):
        pass