"src/targets/gpu/hiprtc/CMakeLists.txt" did not exist on "661046c6777db12922c711657ec35b959833fcc5"
hendrycks.py 4.58 KB
Newer Older
Andy Zou's avatar
Andy Zou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from lm_eval.base import Task
import os
import csv
import numpy as np
from ..utils import sh

SUBJECTS = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 
            'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 
            'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 
            'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 
            'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 
            'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 
            'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 
            'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 
            'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 
            'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']

CHOICES = ['A','B','C','D']

def create_all_tasks():
    """Creates a dictionary of tasks from a list of subjects
    :return: {task_name: task}
        e.g. {hendrycks-abstract_algebra: Task, hendrycks-anatomy: Task}
    """
    return {
        f"hendrycks-{sub}": create_task(sub) for sub in SUBJECTS
    }

def create_task(subject):
    class HendrycksTest(GeneralHendrycksTest):
        def __init__(self):
            super().__init__(subject)
    return HendrycksTest

class GeneralHendrycksTest(Task):

    def __init__(self, subject):
        self.subject = subject
        super().__init__()

    def download(self):
        
        self.data_dir = "data/hendrycks/"
        if not os.path.exists(self.data_dir):
            sh("""
                mkdir -p data
                wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -P data/
                tar -xf data/data.tar -C data/
                rm data/data.tar
                mv data/data data/hendrycks
                """)

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def _load_docs(self, split):

        filename = os.path.join(self.data_dir, split, self.subject + f"_{split}.csv")
        reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')

        docs = []
        for row in reader:
            doc = {
                "query": self._format_example(row),
                "choices": CHOICES,
                "gold": CHOICES.index(row[-1])
            }
            docs.append(doc)
        return docs

    def _format_example(self, row):
        """
            <prompt>
            A. <choice1>
            B. <choice2>
            C. <choice3>
            D. <choice4>
            Answer:
        """
        prompt = row[0]
        k = len(row) - 2
        for j in range(k):
            prompt += "\n{}. {}".format(CHOICES[j], row[j+1])
        prompt += "\nAnswer:"
        return prompt
        
    def training_docs(self):
        raise NotImplementedError

    def validation_docs(self):
        return self._load_docs("val")

    def test_docs(self):
        return self._load_docs("test")

    def doc_to_text(self, doc):
        return doc["query"]

    def doc_to_target(self, doc):
        return " " + doc["answer"]

    def fewshot_docs(self, k):
        assert k >= 5, "Maximum 5 few shot examples."
        return self._load_docs('dev')[:k]

    def fewshot_description(self):
        subject = self.subject.replace("_", " ")
        return f"The following are multiple choice questions (with answers) about {subject}.\n\n"

    def fewshot_context(self, doc, num_fewshot, provide_description):
        raw_description = self.fewshot_description()
        description = raw_description if provide_description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # TODO: crop if over max_len
            labeled_examples = "\n\n".join(
                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in self.fewshot_docs(k=num_fewshot)]
            ) + "\n\n"

        example = self.doc_to_text(doc)
        return description + labeled_examples + example