import numpy as np
import random
from .common import HFTask
from lm_eval.base import rf
from ..metrics import mean


class Pubmed_QA(HFTask):
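    """PubMedQA (Jin et al., 2019): biomedical question answering over PubMed
    abstracts. Each question is answered "yes", "no", or "maybe"; the
    pqa_labeled subset used here contains 1k expert-annotated instances.

    Homepage: https://pubmedqa.github.io/
    """
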
    DATASET_PATH = "pubmed_qa"
    DATASET_NAME = "pqa_labeled"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        if self.has_test_docs():
            # The HF split is labelled "train", but it is really just for testing
            return self.data["train"]

    def fewshot_description(self):
        # Average ctx length in the labelled dataset is 238.9;
        # 2 few-shot examples push it beyond the context window
        return ""

    def doc_to_text(self, doc):
        ctxs = "\n".join(doc["context"]["contexts"])
        return "Abstract: {}\nQuestion: {}\nAnswer:".format(
            ctxs,
            doc["question"],
            doc["final_decision"]
        )

    def doc_to_target(self, doc):
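        # The leading space matches the " yes" / " no" / " maybe" continuations
        # scored in construct_requests.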
        return " {}".format(doc["final_decision"])

    def fewshot_examples(self, k):
        # Only test docs exist, so few-shot examples are sampled from them
        if self._training_docs is None:
            self._training_docs = list(self.test_docs())
        return random.sample(self._training_docs, k)

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns
        an iterable of Requests which will be sent to the LM.
        """
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")
        ll_maybe, _ = rf.loglikelihood(ctx, " maybe")
        return ll_yes, ll_no, ll_maybe

    def process_results(self, doc, results):
        gold = doc["final_decision"]
        ll_yes, ll_no, ll_maybe = results
        pred = np.argmax(results)
        return {
            "acc": ["yes", "no", "maybe"][pred] == gold, 
        }

    def aggregation(self):
        return {
            "acc" : mean
        }

    def higher_is_better(self):
        return {
            "acc" : True
        }
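
# Usage sketch (assumption, not verified against this repo's task registry):
# if this task is registered in lm_eval/tasks/__init__.py under the name
# "pubmedqa", a zero-shot run would look roughly like:
#
#   python main.py --model gpt2 --tasks pubmedqa --num_fewshot 0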