hendrycks_test.py 4.6 KB
Newer Older
Andy Zou's avatar
Andy Zou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from lm_eval.base import Task
import os
import csv
import numpy as np
from ..utils import sh

# Subjects of the hendrycksTest benchmark; create_all_tasks() turns each one
# into a separate ``hendrycksTest-<subject>`` task.
SUBJECTS = [
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions',
]

# Answer letters; CHOICES[j] labels the j-th answer option of a CSV row and
# the last CSV column holds one of these letters as the gold answer.
CHOICES = list('ABCD')

def create_all_tasks():
    """Create one task class per subject in ``SUBJECTS``.

    :return: dict mapping ``"hendrycksTest-<subject>"`` to the task *class*
        (not an instance) produced by ``create_task(subject)``,
        e.g. {"hendrycksTest-abstract_algebra": <class>, "hendrycksTest-anatomy": <class>}
    """
    return {
        f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS
    }

def create_task(subject):
    """Return a ``GeneralHendrycksTest`` subclass pinned to ``subject``.

    The returned class takes no constructor arguments; the subject is
    captured from this function's scope and forwarded to the base class.
    """

    class HendrycksTest(GeneralHendrycksTest):
        def __init__(self):
            # Hand the closed-over subject to the generic implementation.
            super().__init__(subject)

    task_class = HendrycksTest
    return task_class

class GeneralHendrycksTest(Task):
    """Multiple-choice task over a single hendrycksTest subject.

    Data is read from ``<data_dir>/<split>/<subject>_<split>.csv`` for the
    ``dev`` (few-shot), ``val`` and ``test`` splits. Each CSV row is
    ``question, option_1, ..., option_k, answer_letter``.
    """

    def __init__(self, subject):
        # Selects which ``<subject>_<split>.csv`` files are read.
        self.subject = subject
        super().__init__()

    def download(self):
        """Fetch and unpack the dataset into ``data/hendrycksTest/`` if it is absent."""
        # NOTE(review): data_dir is only assigned here; _load_docs relies on
        # download() having run first (presumably via Task.__init__ — confirm).
        self.data_dir = "data/hendrycksTest/"
        if not os.path.exists(self.data_dir):
            sh("""
                mkdir -p data
                wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -P data/
                tar -xf data/data.tar -C data/
                rm data/data.tar
                mv data/data data/hendrycksTest
                """)

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def _load_docs(self, split):
        """Read ``<data_dir>/<split>/<subject>_<split>.csv`` into task docs.

        Each returned doc is a dict with:
          * ``query``   -- prompt built by ``_format_example``
          * ``choices`` -- the answer letters (``CHOICES``)
          * ``gold``    -- index into ``choices`` of the row's answer letter

        Raises:
            ValueError: if a row's last column is not one of ``CHOICES``.
        """
        filename = os.path.join(self.data_dir, split, f"{self.subject}_{split}.csv")
        # newline="" is what the csv module requires so quoted fields that
        # contain line breaks are parsed correctly; ``with`` closes the file.
        with open(filename, "r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f, quotechar='"', delimiter=',')
            return [
                {
                    "query": self._format_example(row),
                    "choices": CHOICES,
                    "gold": CHOICES.index(row[-1]),
                }
                for row in reader
            ]

    def _format_example(self, row):
        """Render a CSV row as a lettered multiple-choice prompt::

            <prompt>
            A. <choice1>
            B. <choice2>
            C. <choice3>
            D. <choice4>
            Answer:
        """
        # row[0] is the question, row[1:-1] the options, row[-1] the answer.
        parts = [row[0]]
        parts.extend(f"{CHOICES[j]}. {option}" for j, option in enumerate(row[1:-1]))
        parts.append("Answer:")
        return "\n".join(parts)

    def training_docs(self):
        raise NotImplementedError

    def validation_docs(self):
        return self._load_docs("val")

    def test_docs(self):
        return self._load_docs("test")

    def doc_to_text(self, doc):
        return doc["query"]

    def doc_to_target(self, doc):
        """Return the gold answer letter with a leading space, e.g. ``" B"``."""
        # Docs from _load_docs carry "choices" and "gold"; there is no "answer" key.
        return " " + doc["choices"][doc["gold"]]

    def fewshot_docs(self, k):
        """Return the first ``k`` docs of the subject's ``dev`` split.

        Raises:
            ValueError: if more than 5 examples are requested.
        """
        if k > 5:
            raise ValueError("Maximum 5 few shot examples.")
        return self._load_docs('dev')[:k]

    def fewshot_description(self):
        subject = self.subject.replace("_", " ")
        return f"The following are multiple choice questions (with answers) about {subject}.\n\n"

    def fewshot_context(self, doc, num_fewshot, provide_description):
        """Build the prompt: optional description, ``num_fewshot`` answered
        dev examples separated by blank lines, then ``doc``'s unanswered query.
        """
        description = self.fewshot_description() if provide_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # TODO: crop if over max_len
            # ``shot`` (not ``doc``) so the few-shot loop doesn't read as
            # shadowing the document being evaluated.
            labeled_examples = "\n\n".join(
                self.doc_to_text(shot) + self.doc_to_target(shot)
                for shot in self.fewshot_docs(k=num_fewshot)
            ) + "\n\n"

        return description + labeled_examples + self.doc_to_text(doc)