pubmedqa.py 3.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
PubMedQA: A Dataset for Biomedical Research Question Answering
https://arxiv.org/pdf/1909.06146.pdf

PubMedQA is a novel biomedical question answering (QA) dataset collected from
PubMed abstracts. The task of PubMedQA is to answer research questions with
yes/no/maybe (e.g., "Do preoperative statins reduce atrial fibrillation after
coronary artery bypass grafting?") using the corresponding abstracts. PubMedQA
has 1k expert-annotated, 61.2k unlabeled, and 211.3k artificially generated QA
instances. Each PubMedQA instance is composed of (1) a question, which is either
an existing research article title or derived from one; (2) a context, which is
the corresponding abstract without its conclusion; (3) a long answer, which is
the conclusion of the abstract and presumably answers the research question;
and (4) a yes/no/maybe answer that summarizes the conclusion.

Homepage: https://pubmedqa.github.io/
"""
jeffhsu3's avatar
jeffhsu3 committed
18
import numpy as np
Jonathan Tow's avatar
Jonathan Tow committed
19
20
from lm_eval.base import rf, Task
from lm_eval.metrics import mean
jeffhsu3's avatar
jeffhsu3 committed
21
22


23
24
25
26
27
28
29
30
31
32
33
# BibTeX entry for the PubMedQA paper (Jin et al., EMNLP-IJCNLP 2019).
_CITATION = """
@inproceedings{jin2019pubmedqa,
    title={PubMedQA: A Dataset for Biomedical Research Question Answering},
    author={Jin, Qiao and Dhingra, Bhuwan and Liu, Zhengping and Cohen, William and Lu, Xinghua},
    booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
    pages={2567--2577},
    year={2019}
}
"""


Jonathan Tow's avatar
Jonathan Tow committed
34
class Pubmed_QA(Task):
    """PubMedQA yes/no/maybe task on the ``pqa_labeled`` subset.

    The prompt is the abstract context followed by the question. The model is
    scored by comparing the log-likelihoods of the continuations " yes",
    " no" and " maybe". The highest-scoring option is the prediction, and
    accuracy is averaged over the documents.
    """

    VERSION = 0
    DATASET_PATH = "pubmed_qa"
    DATASET_NAME = "pqa_labeled"

    # Answer options. construct_requests emits one request per entry in this
    # order, and process_results maps the argmax index back through it, so the
    # two methods must share this single ordering.
    _CHOICES = ("yes", "no", "maybe")

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        """Return the evaluation documents.

        The HF dataset exposes this subset under the "train" split, but it is
        only used for testing here.
        """
        if self.has_test_docs():
            return self.dataset["train"]

    def doc_to_text(self, doc):
        """Build the prompt: the abstract contexts, the question, and an
        ``Answer:`` cue.

        The gold label (``doc["final_decision"]``) is deliberately NOT passed
        to the template, so it cannot leak into the prompt.
        """
        ctxs = "\n".join(doc["context"]["contexts"])
        return "Abstract: {}\nQuestion: {}\nAnswer:".format(ctxs, doc["question"])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        """Text checked for training-set contamination: question plus contexts."""
        return doc["question"] + " " + "\n".join(doc["context"]["contexts"])

    def doc_to_target(self, doc):
        """Gold continuation, space-prefixed to follow ``Answer:``."""
        return " {}".format(doc["final_decision"])

    def construct_requests(self, doc, ctx):
        """Use RequestFactory to build one loglikelihood request per answer
        choice.

        Returns the log-likelihood requests in ``_CHOICES`` order
        (yes, no, maybe). These are sent to the LM.
        """
        lls = []
        for choice in self._CHOICES:
            # rf.loglikelihood yields (loglikelihood, is_greedy); only the
            # log-likelihood is used for scoring.
            ll, _ = rf.loglikelihood(ctx, " " + choice)
            lls.append(ll)
        return tuple(lls)

    def process_results(self, doc, results):
        """Score one document.

        :param results: log-likelihoods in ``_CHOICES`` order, as produced by
            ``construct_requests``.
        :returns: ``{"acc": bool}``, True when the highest-likelihood choice
            equals the gold ``final_decision``.
        """
        gold = doc["final_decision"]
        pred = np.argmax(results)
        return {
            "acc": self._CHOICES[pred] == gold,
        }

    def aggregation(self):
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        return {
            "acc": True,
        }