# piqa.py
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask
# Task definition for the "piqa" dataset.

class PiQA(HFTask):
    """Two-choice task over the ``piqa`` dataset.

    Every document is read as a mapping with these keys:

    * ``goal``  -- the prompt text shown to the model,
    * ``sol1`` / ``sol2`` -- the two candidate continuations,
    * ``label`` -- index of the target continuation (0 -> ``sol1``,
      1 -> ``sol2``).

    The model is scored by comparing the log-likelihood it assigns to each
    candidate and checking whether the higher-scoring one matches ``label``.
    """

    DATASET_PATH = "piqa"
    DATASET_NAME = None

    def has_training_docs(self):
        """Return ``True``: this task declares a training split."""
        return True

    def has_validation_docs(self):
        """Return ``True``: this task declares a validation split."""
        return True

    def has_test_docs(self):
        """Return ``False``: this task does not declare a test split."""
        return False

    def fewshot_description(self):
        """Return the few-shot preamble, which is currently empty."""
        # TODO: figure out fewshot description
        return ""

    def doc_to_text(self, doc):
        """Return the prompt for ``doc``: its goal followed by a newline."""
        return doc["goal"] + "\n"

    def doc_to_target(self, doc):
        """Return the candidate solution selected by ``doc["label"]``."""
        solutions = [doc["sol1"], doc["sol2"]]
        return solutions[doc["label"]]

    def construct_requests(self, doc, ctx):
        """Build one log-likelihood request per candidate solution.

        Args:
            doc: The document providing ``sol1`` and ``sol2``.
            ctx: The context string each candidate is scored against.

        Returns:
            A pair ``(ll_1, ll_2)`` holding the first element of each
            ``rf.loglikelihood`` request, in ``sol1``/``sol2`` order. The
            second element of each request is not used by this task.
        """
        ll_1, _ = rf.loglikelihood(ctx, doc['sol1'])
        ll_2, _ = rf.loglikelihood(ctx, doc['sol2'])
        return ll_1, ll_2

    def process_results(self, doc, results):
        """Score one document from its per-candidate results.

        Args:
            doc: The document whose ``label`` gives the target index.
            results: Per-candidate scores; presumably the evaluated values
                of the requests from ``construct_requests``, in the same
                order -- confirm against the caller.

        Returns:
            A dict with a single ``'acc'`` entry that is true when the
            index of the largest score equals ``doc["label"]``.
        """
        return {
            'acc': np.argmax(results) == doc["label"]
        }

    def aggregation(self):
        """Return the aggregation function for each metric (``mean`` for ``'acc'``)."""
        return {
            'acc': mean
        }

    def higher_is_better(self):
        """Return, per metric, whether larger values are better (``True`` for ``'acc'``)."""
        return {
            'acc': True
        }