"include/ck/utility/amd_buffer_addressing.hpp" did not exist on "a7a758d8ceef978d14de97df4a3c67fa97a20a13"
Commit 410ca65f authored by Jonathan Tow
Browse files

Split `test` and `train` doc loading and address minor fix-ups

parent 6d997710
...@@ -59,13 +59,13 @@ The goal of this project is to build a set of tools for evaluating LMs on typica ...@@ -59,13 +59,13 @@ The goal of this project is to build a set of tools for evaluating LMs on typica
|ethics_utilitarianism_original|✓ |✓ |✓ |acc | |ethics_utilitarianism_original|✓ |✓ |✓ |acc |
|ethics_utilitarianism |✓ |✓ |✓ |acc | |ethics_utilitarianism |✓ |✓ |✓ |acc |
|ethics_virtue |✓ |✓ |✓ |acc, em | |ethics_virtue |✓ |✓ |✓ |acc, em |
|math_algebra |✓ | |✓ |acc | |math_algebra |✓ | |✓ |acc |
|math_counting_and_prob |✓ | |✓ |acc | |math_counting_and_prob |✓ | |✓ |acc |
|math_geometry |✓ | |✓ |acc | |math_geometry |✓ | |✓ |acc |
|math_intermediate_algebra |✓ | |✓ |acc | |math_intermediate_algebra |✓ | |✓ |acc |
|math_num_theory |✓ | |✓ |acc | |math_num_theory |✓ | |✓ |acc |
|math_prealgebra |✓ | |✓ |acc | |math_prealgebra |✓ | |✓ |acc |
|math_precalc |✓ | |✓ |acc | |math_precalc |✓ | |✓ |acc |
|arithmetic_2da | |✓ | |acc | |arithmetic_2da | |✓ | |acc |
|arithmetic_2ds | |✓ | |acc | |arithmetic_2ds | |✓ | |acc |
|arithmetic_3da | |✓ | |acc | |arithmetic_3da | |✓ | |acc |
......
import json import json
import os import random
from lm_eval.utils import sh from lm_eval.utils import sh
from lm_eval.metrics import mean from lm_eval.metrics import mean
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from pathlib import Path
import abc import abc
class Math(Task): class Math(Task):
directory = 'data/MATH' """
This dataset is based on the following paper:
https://arxiv.org/abs/2103.03874
"""
DATASET_PATH = Path('data/MATH')
def download(self): def download(self):
if not os.path.exists('data/'): if not self.DATASET_PATH.exists():
sh("mkdir data") sh(f"""
if not os.path.exists('data/MATH/'): mkdir -p {self.DATASET_PATH}
sh("wget https://people.eecs.berkeley.edu/~hendrycks/MATH.tar.gz -P data/") wget https://people.eecs.berkeley.edu/~hendrycks/MATH.tar.gz -P data/
sh("tar -xvf data/MATH.tar.gz -C data/") tar -xvf {self.DATASET_PATH}.tar.gz -C data/
sh("rm data/MATH.tar.gz") rm {self.DATASET_PATH}.tar.gz
self.set_docs() """)
@abc.abstractmethod @abc.abstractmethod
def get_file_info(self): def get_file_info(self):
"""returns directory name""" """returns directory name"""
pass pass
def set_docs(self):
self._training_docs = []
self._testing_docs = []
dir_name = self.get_file_info()
train_path = os.path.join('data/MATH/train',dir_name).replace("\\", "/")
test_path = os.path.join('data/MATH/test',dir_name).replace("\\", "/")
for filename in os.listdir(train_path):
with open(os.path.join(train_path, filename).replace("\\", "/")) as f:
self._training_docs.append(json.load(f))
for filename in os.listdir(test_path):
with open(os.path.join(test_path, filename).replace("\\", "/")) as f:
self._testing_docs.append(json.load(f))
for doc in self._testing_docs:
doc["answer"] = self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -49,13 +39,25 @@ class Math(Task): ...@@ -49,13 +39,25 @@ class Math(Task):
return True return True
def training_docs(self): def training_docs(self):
return self._training_docs path = self.DATASET_PATH / "train" / self.get_file_info()
for file in path.iterdir():
with open(file) as f:
yield json.load(f)
def validation_docs(self): def validation_docs(self):
return NotImplemented return NotImplemented
def test_docs(self): def test_docs(self):
return self._testing_docs path = self.DATASET_PATH / "test" / self.get_file_info()
for file in path.iterdir():
with open(file) as f:
doc = json.load(f)
doc["answer"] = self.remove_boxed(
self.last_boxed_only_string(doc["solution"]))
yield doc
def fewshot_description(self):
return "Given a mathematics problem, determine the answer. Simplify your answer as much as possible."
def fewshot_examples(self, k): def fewshot_examples(self, k):
assert k <= 8, "There are only 8 possible shots for this task." assert k <= 8, "There are only 8 possible shots for this task."
...@@ -64,26 +66,21 @@ class Math(Task): ...@@ -64,26 +66,21 @@ class Math(Task):
{"problem": "In how many ways can 4 books be selected from a shelf of 6 books if the order in which the books are selected does not matter?", "answer": "$15$"}, {"problem": "In how many ways can 4 books be selected from a shelf of 6 books if the order in which the books are selected does not matter?", "answer": "$15$"},
{"problem": "Find the distance between the points $(2,1,-4)$ and $(5,8,-3).$", "answer": "$\sqrt{59}$"}, {"problem": "Find the distance between the points $(2,1,-4)$ and $(5,8,-3).$", "answer": "$\sqrt{59}$"},
{"problem": "The faces of an octahedral die are labeled with digits $1$ through $8$. What is the probability, expressed as a common fraction, of rolling a sum of $15$ with a pair of such octahedral dice?", "answer": "$\\frac{1}{32}$"}, {"problem": "The faces of an octahedral die are labeled with digits $1$ through $8$. What is the probability, expressed as a common fraction, of rolling a sum of $15$ with a pair of such octahedral dice?", "answer": "$\\frac{1}{32}$"},
{"problem": "The first three terms of an arithmetic sequence are 1, 10 and 19, respectively. What is the value of the 21st term?" , "answer": "$181$"}, {"problem": "The first three terms of an arithmetic sequence are 1, 10 and 19, respectively. What is the value of the 21st term?", "answer": "$181$"},
{"problem": "Calculate $6 \\cdot 8\\frac{1}{3}", "answer": "$50$"}, {"problem": "Calculate $6 \\cdot 8\\frac{1}{3}", "answer": "$50$"},
{"problem": "When the binary number $100101110010_2$ is divided by 4, what is the remainder (give your answer in base 10)?", "answer": "$2$"}, {"problem": "When the binary number $100101110010_2$ is divided by 4, what is the remainder (give your answer in base 10)?", "answer": "$2$"},
{"problem": "How many zeros are at the end of the product 25 $\\times$ 240?", "answer": "$3$"} {"problem": "How many zeros are at the end of the product 25 $\\times$ 240?", "answer": "$3$"}
] ]
return prompts[:k] return random.sample(prompts, k)
def doc_to_text(self, doc):
return "\n" + "###" + "\n" + "Problem: " + doc["problem"] + "\n" + "Answer:"
def doc_to_text(self, doc):
return "Problem: " + doc["problem"] + "\nAnswer:"
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc["answer"] return " " + doc["answer"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
to_send = self.doc_to_text(doc) return rf.greedy_until(ctx, ["\n"])
answer = self.doc_to_target(doc)
ll = rf.greedy_until(ctx + to_send, "\n")
return ll
def process_results(self, doc, results): def process_results(self, doc, results):
retval = 0 retval = 0
...@@ -109,22 +106,6 @@ class Math(Task): ...@@ -109,22 +106,6 @@ class Math(Task):
'acc': True 'acc': True
} }
def fewshot_description(self):
return "Given a mathematics problem, determine the answer. Simplify your answer as much as possible." + "\n"
def fewshot_context(self, doc, num_fewshot, provide_description):
description = self.fewshot_description()
if num_fewshot == 0:
labeled_examples = ""
else:
labeled_examples = "\n\n".join(
[self.doc_to_text(doc) + self.doc_to_target(doc) + "\n" + "###" + "\n" for doc in self.fewshot_examples(k=num_fewshot)]
) + "\n\n"
return description + labeled_examples
def is_equiv(self, str1, str2, verbose=False): def is_equiv(self, str1, str2, verbose=False):
if str1 is None and str2 is None: if str1 is None and str2 is None:
print("WARNING: Both None") print("WARNING: Both None")
...@@ -147,7 +128,7 @@ class Math(Task): ...@@ -147,7 +128,7 @@ class Math(Task):
assert s[:len(left)] == left assert s[:len(left)] == left
assert s[-1] == "}" assert s[-1] == "}"
return s[len(left):-1] return s[len(left):-1]
except: except AssertionError:
return None return None
def last_boxed_only_string(self, string): def last_boxed_only_string(self, string):
...@@ -169,12 +150,12 @@ class Math(Task): ...@@ -169,12 +150,12 @@ class Math(Task):
right_brace_idx = i right_brace_idx = i
break break
i += 1 i += 1
if right_brace_idx == None: if right_brace_idx is None:
retval = None retval = None
else: else:
retval = string[idx:right_brace_idx + 1] retval = string[idx:right_brace_idx + 1]
return retval return retval
def fix_fracs(self, string): def fix_fracs(self, string):
...@@ -189,7 +170,7 @@ class Math(Task): ...@@ -189,7 +170,7 @@ class Math(Task):
else: else:
try: try:
assert len(substr) >= 2 assert len(substr) >= 2
except: except AssertionError:
return string return string
a = substr[0] a = substr[0]
b = substr[1] b = substr[1]
...@@ -219,7 +200,7 @@ class Math(Task): ...@@ -219,7 +200,7 @@ class Math(Task):
assert string == "{}/{}".format(a, b) assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string return new_string
except: except AssertionError:
return string return string
def remove_right_units(self, string): def remove_right_units(self, string):
...@@ -235,7 +216,7 @@ class Math(Task): ...@@ -235,7 +216,7 @@ class Math(Task):
if "\\sqrt" not in string: if "\\sqrt" not in string:
return string return string
splits = string.split("\\sqrt") splits = string.split("\\sqrt")
new_string = splits[0] new_string = splits[0]
for split in splits[1:]: for split in splits[1:]:
if split[0] != "{": if split[0] != "{":
a = split[0] a = split[0]
...@@ -250,7 +231,7 @@ class Math(Task): ...@@ -250,7 +231,7 @@ class Math(Task):
return False return False
def strip_string(self, string): def strip_string(self, string):
# linebreaks # linebreaks
string = string.replace("\n", "") string = string.replace("\n", "")
# remove inverse spaces # remove inverse spaces
...@@ -266,18 +247,17 @@ class Math(Task): ...@@ -266,18 +247,17 @@ class Math(Task):
# remove \left and \right # remove \left and \right
string = string.replace("\\left", "") string = string.replace("\\left", "")
string = string.replace("\\right", "") string = string.replace("\\right", "")
# Remove circ (degrees) # Remove circ (degrees)
string = string.replace("^{\\circ}", "") string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "") string = string.replace("^\\circ", "")
# remove dollar signs # remove dollar signs
string = string.replace("\\$", "") string = string.replace("\\$", "")
# remove units (on the right) # remove units (on the right)
string = self.remove_right_units(string) string = self.remove_right_units(string)
# remove percentage # remove percentage
string = string.replace("\\%", "") string = string.replace("\\%", "")
string = string.replace("\%", "") string = string.replace("\%", "")
...@@ -314,30 +294,37 @@ class Math(Task): ...@@ -314,30 +294,37 @@ class Math(Task):
return string return string
class MathAlgebra(Math): class MathAlgebra(Math):
def get_file_info(self): def get_file_info(self):
return 'algebra' return 'algebra'
class MathCountingAndProbability(Math): class MathCountingAndProbability(Math):
def get_file_info(self): def get_file_info(self):
return 'counting_and_probability' return 'counting_and_probability'
class MathGeometry(Math): class MathGeometry(Math):
def get_file_info(self): def get_file_info(self):
return 'geometry' return 'geometry'
class MathIntermediateAlgebra(Math): class MathIntermediateAlgebra(Math):
def get_file_info(self): def get_file_info(self):
return 'intermediate_algebra' return 'intermediate_algebra'
class MathNumberTheory(Math): class MathNumberTheory(Math):
def get_file_info(self): def get_file_info(self):
return 'number_theory' return 'number_theory'
class MathPrealgebra(Math): class MathPrealgebra(Math):
def get_file_info(self): def get_file_info(self):
return 'prealgebra' return 'prealgebra'
class MathPrecalculus(Math): class MathPrecalculus(Math):
def get_file_info(self): def get_file_info(self):
return 'precalculus' return 'precalculus'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment