Commit 806b022b authored by Jonathan Tow
Browse files

Refactor and simple formatting

parent 13b97626
from lm_eval.base import MultipleChoiceTask
import os
import csv
import numpy as np
import random
from lm_eval.base import MultipleChoiceTask
from ..utils import sh
from pathlib import Path

# Subject identifiers of the Hendrycks test. Each entry is used to build a task
# name ("hendrycksTest-<subject>") and the per-subject CSV file names
# ("<subject>_dev.csv", "<subject>_test.csv").
SUBJECTS = [
    'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology',
    'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics',
    'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics',
    'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
    'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics',
    'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics',
    'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence',
    'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes',
    'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine',
    'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions',
]
CHOICES = ['A','B','C','D']
def create_all_tasks():
"""Creates a dictionary of tasks from a list of subjects
......@@ -26,22 +25,23 @@ def create_all_tasks():
f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS
}
def create_task(subject):
    """Build a task class bound to a single subject.

    Returns a `GeneralHendrycksTest` subclass (the class itself, not an
    instance) whose no-argument constructor forwards `subject` to the parent
    constructor.
    """
    class HendrycksTest(GeneralHendrycksTest):
        def __init__(self):
            # `subject` is captured from the enclosing create_task() call.
            super().__init__(subject)

    return HendrycksTest
class GeneralHendrycksTest(MultipleChoiceTask):
DATASET_PATH = Path("data/hendrycksTest/")
def __init__(self, subject):
    """Store the subject name, then run the parent constructor.

    `subject` is assigned before `super().__init__()` is called, so the
    attribute is already set while the parent constructor runs.
    """
    self.subject = subject
    super().__init__()
def download(self):
self.data_dir = "data/hendrycksTest/"
if not os.path.exists(self.data_dir):
if not self.DATASET_PATH.exists():
sh("""
mkdir -p data
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -P data/
......@@ -59,79 +59,41 @@ class GeneralHendrycksTest(MultipleChoiceTask):
def has_test_docs(self):
    """Report that this task provides a test split (always True)."""
    return True
def _convert_standard(self, doc):
    """Turn one CSV row into a multiple-choice document.

    `doc` is indexed as: [0] the question text, [1:5] the answer options and
    [5] the letter ('A'-'D') of the correct option.

    Returns a dict with "query" (the prompt text), "choices" (the option
    texts) and "gold" (the 0-based index of the correct option).

    Raises ValueError if `doc[5]` is not one of 'A', 'B', 'C', 'D'.
    """
    return {
        "query": "Question: " + doc[0] + "\nAnswer:",
        "choices": doc[1:5],
        "gold": ['A', 'B', 'C', 'D'].index(doc[5]),
    }


def _load_docs(self, filename):
    """Read a CSV file and return its rows as a list of documents.

    Every row is converted with `_convert_standard`. The result is a list
    (not a lazy generator) so the file can be closed before returning and so
    callers may iterate, slice or extend with it.
    """
    # newline='' is what the csv module expects, so that newlines embedded in
    # quoted fields are handled by the reader itself.
    with open(filename, 'r', newline='') as csv_file:
        reader = csv.reader(csv_file, quotechar='"', delimiter=',')
        return [self._convert_standard(row) for row in reader]
def _format_example(self, row):
    """Render one row as a lettered multiple-choice prompt.

    Layout of the returned string:

        <prompt>
        A. <choice1>
        B. <choice2>
        C. <choice3>
        D. <choice4>
        Answer:

    `row[0]` is the prompt and `row[1]`..`row[4]` are the four choices.
    """
    question = row[0]
    option_lines = ["{}. {}".format(CHOICES[idx], row[idx + 1]) for idx in range(4)]
    return "\n".join([question, *option_lines, "Answer:"])
def training_docs(self):
    """Collect the training documents.

    Uses every file in the `auxiliary_train`, `dev` and `val` directories
    under `DATASET_PATH` (auxiliary_train includes some UnifiedQA MC tasks)
    and returns all documents loaded from them as one list.
    """
    docs = []
    for train_dir in ["auxiliary_train", "dev", "val"]:
        # Directory listing order is not defined by the OS; sorting keeps the
        # resulting document order reproducible between runs.
        for f in sorted((self.DATASET_PATH / train_dir).iterdir()):
            docs.extend(self._load_docs(f))
    return docs
def validation_docs(self):
    """Validation documents are not provided by this task.

    Raises NotImplementedError unconditionally.
    """
    raise NotImplementedError
def test_docs(self):
    """Load the test split for this task's subject.

    Reads `<DATASET_PATH>/test/<subject>_test.csv` through `_load_docs` and
    returns whatever `_load_docs` returns.
    """
    # Built from the class-level DATASET_PATH so the path does not depend on
    # `self.data_dir`, which is only assigned inside download().
    filename = self.DATASET_PATH / "test" / f"{self.subject}_test.csv"
    return self._load_docs(filename)
def doc_to_text(self, doc):
    """Return the prompt text stored under the document's "query" key."""
    prompt = doc["query"]
    return prompt
def doc_to_target(self, doc):
    """Return the expected completion for `doc`, prefixed with one space.

    Documents built by this task carry "choices" and "gold" (the index of the
    correct choice) rather than an "answer" key, so the target is looked up as
    `doc["choices"][doc["gold"]]`. A document that does provide "answer" is
    still honoured and its value is used directly.
    """
    if "answer" in doc:
        return " " + doc["answer"]
    return " " + doc["choices"][doc["gold"]]
def fewshot_docs(self, k):
    """Return the first `k` documents of this subject's dev file.

    `k` must not exceed 5; an AssertionError with the message
    "Maximum 5 few shot examples." is raised otherwise.
    """
    # The limit is an upper bound, matching the error message.
    assert k <= 5, "Maximum 5 few shot examples."
    filename = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
    # Materialise first: _load_docs may hand back a lazy iterable, which
    # cannot be sliced directly.
    return list(self._load_docs(filename))[:k]
def fewshot_examples(self, k):
    """Draw `k` random documents from this subject's dev file.

    `k` may be at most 5 (AssertionError otherwise). The dev file
    `<DATASET_PATH>/dev/<subject>_dev.csv` is loaded through `_load_docs` and
    sampled without replacement using the global `random` state.
    """
    assert k <= 5, "Maximum 5 few shot examples."
    dev_csv = self.DATASET_PATH / "dev" / f"{self.subject}_dev.csv"
    pool = list(self._load_docs(dev_csv))
    return random.sample(pool, k)
def fewshot_description(self):
    """Return the header sentence placed before the few-shot examples.

    Underscores in the subject name are shown as spaces, and the sentence
    ends with a blank line.
    """
    readable = self.subject.replace("_", " ")
    header = f"The following are multiple choice questions (with answers) about {readable}.\n\n"
    return header
def fewshot_context(self, doc, num_fewshot, provide_description):
    """Assemble the full prompt for `doc`.

    The result is the concatenation of:
      * the task description from `fewshot_description()` (only when
        `provide_description` is true),
      * `num_fewshot` solved examples from `fewshot_docs`, each rendered as
        `doc_to_text(...) + doc_to_target(...)`, separated and followed by a
        blank line (nothing when `num_fewshot` is 0),
      * the unanswered prompt for `doc` itself.
    """
    raw_description = self.fewshot_description()
    description = raw_description if provide_description else ""

    if num_fewshot == 0:
        labeled_examples = ""
    else:
        # TODO: crop if over max_len
        solved = [
            self.doc_to_text(shot) + self.doc_to_target(shot)
            for shot in self.fewshot_docs(k=num_fewshot)
        ]
        labeled_examples = "\n\n".join(solved) + "\n\n"

    example = self.doc_to_text(doc)
    return description + labeled_examples + example
def doc_to_text(self, doc):
    """Return the prompt text stored under the document's "query" key."""
    prompt = doc["query"]
    return prompt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment