"""
"""

import json
import os
import zipfile

import numpy as np

from lm_eval.base import MultipleChoiceTask, rf, mean

from ..utils import sh
from .common import HFTask


class Pubmed_QA(HFTask):
    DATASET_PATH = "pubmed_qa"
    DATASET_NAME = "pqa_labeled"
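    # HF "pubmed_qa" dataset, "pqa_labeled" config: the 1k expert-annotated
    # question/abstract pairs with yes/no/maybe final decisions.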

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def fewshot_description(self):
        # Average ctx length in the labelled dataset is 238.9;
        # 2 few-shot examples pushes it beyond the context window.
        return ""

    def doc_to_text(self, doc):
        ctxs = "\n".join(doc["context"]["contexts"])
        return "abstract: {}\nquestion: {}\nanswer:".format(
            ctxs,
            doc["question"],
        )

    def doc_to_target(self, doc):
        return " {}".format(doc["final_decision"])

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns
        an iterable of Requests which will be sent to the LM.
        """
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")
        ll_maybe, _ = rf.loglikelihood(ctx, " maybe")
        return ll_yes, ll_no, ll_maybe

    def process_results(self, doc, results):
        gold = doc["final_decision"]
        ll_yes, ll_no, ll_maybe = results
        # Predict the answer with the highest loglikelihood.
        pred = np.argmax([ll_yes, ll_no, ll_maybe])
        return {
            "acc": ["yes", "no", "maybe"][pred] == gold,
        }

    def aggregation(self):
        return {
            "acc": mean
        }

    def higher_is_better(self):
        return {
            "acc": True
        }

    def test_docs(self):
        if self.has_test_docs():
            # The HF split is labelled "train", but it's really just for testing.
            return self.data["train"]


class SciQ(MultipleChoiceTask):
    def download(self):
        # Skip the download when the data is already present.
        if not os.path.exists("data/sciq"):
            os.makedirs("data/sciq", exist_ok=True)
            sh("wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip")
            with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
                zf.extractall("data/sciq/")

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def _convert_standard(self, doc):
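        # The correct answer is always appended last, so the gold index is 3.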
        choices = [
            doc["distractor1"],
            doc["distractor2"],
            doc["distractor3"],
            doc["correct_answer"],
        ]
        out_doc = {
            "source": doc["support"],
            "query": doc["question"],
            "choices": choices,
            "gold": 3,
        }
        return out_doc

    def load_docs(self, textfilename):
        with open(textfilename, "r") as j:
            docs = json.load(j)
        for record in docs:
            yield self._convert_standard(record)

    def fewshot_description(self):
        # Average ctx length in the labelled dataset is 238.9;
        # 2 few-shot examples pushes it beyond the context window.
        return ""

    def training_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/train.json")

    def validation_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/valid.json")

    def test_docs(self):
        return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")

    def doc_to_text(self, doc):
        return " {}\n{}".format(doc["source"], doc["query"])


class EmrQA:
    def load_docs(self, textfilename):
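        # Stub: EmrQA document loading is not yet implemented.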
        pass