"vscode:/vscode.git/clone" did not exist on "9b22bdd9ff44bf39bd35ce9721fe5020fa28006b"
Commit 7989168d authored by Andy Zou's avatar Andy Zou
Browse files

added training

parent 8888a6be
from lm_eval.base import Task from lm_eval.base import MultipleChoiceTask
import os import os
import csv import csv
import numpy as np import numpy as np
...@@ -32,7 +32,7 @@ def create_task(subject): ...@@ -32,7 +32,7 @@ def create_task(subject):
super().__init__(subject) super().__init__(subject)
return HendrycksTest return HendrycksTest
class GeneralHendrycksTest(Task): class GeneralHendrycksTest(MultipleChoiceTask):
def __init__(self, subject): def __init__(self, subject):
self.subject = subject self.subject = subject
...@@ -44,24 +44,22 @@ class GeneralHendrycksTest(Task): ...@@ -44,24 +44,22 @@ class GeneralHendrycksTest(Task):
if not os.path.exists(self.data_dir): if not os.path.exists(self.data_dir):
sh(""" sh("""
mkdir -p data mkdir -p data
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -P data/ wget https://people.eecs.berkeley.edu/~hendrycks/hendrycksTest.tar.gz -P data/
tar -xf data/data.tar -C data/ tar -xf data/hendrycksTest.tar.gz -C data/
rm data/data.tar rm data/hendrycksTest.tar.gz
mv data/data data/hendrycksTest
""") """)
def has_training_docs(self): def has_training_docs(self):
return False return True
def has_validation_docs(self): def has_validation_docs(self):
return True return False
def has_test_docs(self): def has_test_docs(self):
return True return True
def _load_docs(self, split): def _load_docs(self, filename):
filename = os.path.join(self.data_dir, split, self.subject + f"_{split}.csv")
reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',') reader = csv.reader(open(filename, 'r'), quotechar='"', delimiter=',')
docs = [] docs = []
...@@ -69,7 +67,7 @@ class GeneralHendrycksTest(Task): ...@@ -69,7 +67,7 @@ class GeneralHendrycksTest(Task):
doc = { doc = {
"query": self._format_example(row), "query": self._format_example(row),
"choices": CHOICES, "choices": CHOICES,
"gold": CHOICES.index(row[-1]) "gold": CHOICES.index(row[5])
} }
docs.append(doc) docs.append(doc)
return docs return docs
...@@ -84,20 +82,27 @@ class GeneralHendrycksTest(Task): ...@@ -84,20 +82,27 @@ class GeneralHendrycksTest(Task):
Answer: Answer:
""" """
prompt = row[0] prompt = row[0]
k = len(row) - 2 for j in range(4):
for j in range(k):
prompt += "\n{}. {}".format(CHOICES[j], row[j+1]) prompt += "\n{}. {}".format(CHOICES[j], row[j+1])
prompt += "\nAnswer:" prompt += "\nAnswer:"
return prompt return prompt
def training_docs(self): def training_docs(self):
raise NotImplementedError docs = []
# Use all files in the train, dev, val directories (including some UnifiedQA MC tasks)
for train_dir in ["train", "dev", "val"]:
train_dir = os.path.join(self.data_dir, train_dir)
for f in os.listdir(train_dir):
filename = os.path.join(train_dir, f)
docs.extend(self._load_docs(filename))
return docs
def validation_docs(self): def validation_docs(self):
return self._load_docs("val") raise NotImplementedError
def test_docs(self): def test_docs(self):
return self._load_docs("test") filename = os.path.join(self.data_dir, "test", self.subject + f"_test.csv")
return self._load_docs(filename)
def doc_to_text(self, doc): def doc_to_text(self, doc):
return doc["query"] return doc["query"]
...@@ -107,7 +112,8 @@ class GeneralHendrycksTest(Task): ...@@ -107,7 +112,8 @@ class GeneralHendrycksTest(Task):
def fewshot_docs(self, k): def fewshot_docs(self, k):
assert k >= 5, "Maximum 5 few shot examples." assert k >= 5, "Maximum 5 few shot examples."
return self._load_docs('dev')[:k] filename = os.path.join(self.data_dir, "dev", self.subject + f"_dev.csv")
return self._load_docs(filename)[:k]
def fewshot_description(self): def fewshot_description(self):
subject = self.subject.replace("_", " ") subject = self.subject.replace("_", " ")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment