Commit baadb2da authored by Kevin Wang's avatar Kevin Wang
Browse files

use openai's gpt3 datafiles instead of generating equations

parent 183f4721
import abc import abc
import json import json
import random import os
from collections import namedtuple from collections import namedtuple
from lm_eval.base import Dataset, mean, rf from lm_eval.base import Dataset, mean, rf
from best_download import download_file
ArithmeticDoc = namedtuple('ArithmeticDoc', ['question_text', 'answer_text']) ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Dataset): class Arithmetic(Dataset):
def __init__(self, number_of_problems=2000): directory = 'data/arithmetic/'
def __init__(self):
super().__init__() super().__init__()
self.problems = self.generate_problems(number_of_problems) self.set_docs()
def download(self):
file_name, checksum = self.get_file_download_info()
url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
if not os.path.exists(self.directory):
os.makedirs(self.directory)
download_file(url, self.directory+file_name, checksum)
@abc.abstractmethod @abc.abstractmethod
def generate_problems(self, number_of_problems): def get_file_download_info(self):
"""returns a tuple of (file_name, checksum)"""
pass pass
def set_docs(self):
file_name, _ = self.get_file_download_info()
jsons = open(self.directory+file_name, 'r')
self._docs = [self.load_doc(json.loads(line)) for line in jsons]
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -25,23 +41,25 @@ class Arithmetic(Dataset): ...@@ -25,23 +41,25 @@ class Arithmetic(Dataset):
return False return False
def training_docs(self): def training_docs(self):
return self.problems return self._docs
def validation_docs(self): def validation_docs(self):
return self.generate_problems(50) return self._docs[:100]
def test_docs(self): def test_docs(self):
return NotImplemented return NotImplemented
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"Q: What is {doc.question_text}?\nA: " return doc.context
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc.answer_text return doc.completion
def load_doc(self, doc_json):
return ArithmeticDoc(context=doc_json['context'], completion=doc_json['completion'])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll, is_prediction = rf.loglikelihood(ctx, ' '+doc.answer_text) ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
# not sure what the difference between the two objects returned by rf.loglikehood are here
return is_prediction return is_prediction
def process_results(self, doc, results): def process_results(self, doc, results):
...@@ -62,92 +80,41 @@ class Arithmetic(Dataset): ...@@ -62,92 +80,41 @@ class Arithmetic(Dataset):
class Arithmetic2DPlus(Arithmetic): class Arithmetic2DPlus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
for i in range(number_of_problems):
x,y = random.randint(0,99), random.randint(0,99)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
class Arithmetic2DMinus(Arithmetic): class Arithmetic2DMinus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
for i in range(number_of_problems):
x,y = random.randint(0,99), random.randint(0,99)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
class Arithmetic3DPlus(Arithmetic): class Arithmetic3DPlus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
for i in range(number_of_problems):
x,y = random.randint(0,999), random.randint(0,999)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
class Arithmetic3DMinus(Arithmetic): class Arithmetic3DMinus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
for i in range(number_of_problems):
x,y = random.randint(0,999), random.randint(0,999)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
class Arithmetic4DPlus(Arithmetic): class Arithmetic4DPlus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
for i in range(number_of_problems):
x,y = random.randint(0,9999), random.randint(0,9999)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
class Arithmetic4DMinus(Arithmetic): class Arithmetic4DMinus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
for i in range(number_of_problems):
x,y = random.randint(0,9999), random.randint(0,9999)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
class Arithmetic5DPlus(Arithmetic): class Arithmetic5DPlus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
for i in range(number_of_problems):
x,y = random.randint(0,99999), random.randint(0,99999)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
class Arithmetic5DMinus(Arithmetic): class Arithmetic5DMinus(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
for i in range(number_of_problems):
x,y = random.randint(0,99999), random.randint(0,99999)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
class Arithmetic2DMultiplication(Arithmetic): class Arithmetic2DMultiplication(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
for i in range(number_of_problems):
x,y = random.randint(0,99), random.randint(0,99)
a = x*y
l.append(ArithmeticDoc(question_text=f'{x}*{y}', answer_text=f'{a}'))
return l
class Arithmetic1DComposite(Arithmetic): class Arithmetic1DComposite(Arithmetic):
def generate_problems(self, number_of_problems): def get_file_download_info(self):
l = [] return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
for i in range(number_of_problems):
x,y,z = random.randint(0,9), random.randint(0,9), random.randint(0,9)
op1, op2 = random.choice('-+*'), random.choice('-+*')
to_eval = f'{x}{op1}({y}{op2}{z})'
l.append(ArithmeticDoc(question_text=to_eval, answer_text=str(eval(to_eval))))
return l
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment