Commit baadb2da authored by Kevin Wang's avatar Kevin Wang
Browse files

use openai's gpt3 datafiles instead of generating equations

parent 183f4721
import abc
import json
import random
import os
from collections import namedtuple
from lm_eval.base import Dataset, mean, rf
from best_download import download_file
ArithmeticDoc = namedtuple('ArithmeticDoc', ['question_text', 'answer_text'])
ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
class Arithmetic(Dataset):
def __init__(self, number_of_problems=2000):
directory = 'data/arithmetic/'
def __init__(self):
super().__init__()
self.problems = self.generate_problems(number_of_problems)
self.set_docs()
def download(self):
file_name, checksum = self.get_file_download_info()
url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
if not os.path.exists(self.directory):
os.makedirs(self.directory)
download_file(url, self.directory+file_name, checksum)
@abc.abstractmethod
def generate_problems(self, number_of_problems):
def get_file_download_info(self):
"""returns a tuple of (file_name, checksum)"""
pass
def set_docs(self):
file_name, _ = self.get_file_download_info()
jsons = open(self.directory+file_name, 'r')
self._docs = [self.load_doc(json.loads(line)) for line in jsons]
def has_training_docs(self):
return True
......@@ -25,23 +41,25 @@ class Arithmetic(Dataset):
return False
def training_docs(self):
return self.problems
return self._docs
def validation_docs(self):
return self.generate_problems(50)
return self._docs[:100]
def test_docs(self):
return NotImplemented
def doc_to_text(self, doc):
return f"Q: What is {doc.question_text}?\nA: "
return doc.context
def doc_to_target(self, doc):
return doc.answer_text
return doc.completion
def load_doc(self, doc_json):
return ArithmeticDoc(context=doc_json['context'], completion=doc_json['completion'])
def construct_requests(self, doc, ctx):
ll, is_prediction = rf.loglikelihood(ctx, ' '+doc.answer_text)
# not sure what the difference between the two objects returned by rf.loglikehood are here
ll, is_prediction = rf.loglikelihood(ctx, doc.completion)
return is_prediction
def process_results(self, doc, results):
......@@ -62,92 +80,41 @@ class Arithmetic(Dataset):
class Arithmetic2DPlus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,99), random.randint(0,99)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
class Arithmetic2DMinus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,99), random.randint(0,99)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
class Arithmetic3DPlus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,999), random.randint(0,999)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
class Arithmetic3DMinus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,999), random.randint(0,999)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
class Arithmetic4DPlus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,9999), random.randint(0,9999)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
class Arithmetic4DMinus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,9999), random.randint(0,9999)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
class Arithmetic5DPlus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,99999), random.randint(0,99999)
a = x+y
l.append(ArithmeticDoc(question_text=f'{x}+{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
class Arithmetic5DMinus(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,99999), random.randint(0,99999)
a = x-y
l.append(ArithmeticDoc(question_text=f'{x}-{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
class Arithmetic2DMultiplication(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y = random.randint(0,99), random.randint(0,99)
a = x*y
l.append(ArithmeticDoc(question_text=f'{x}*{y}', answer_text=f'{a}'))
return l
def get_file_download_info(self):
return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
class Arithmetic1DComposite(Arithmetic):
def generate_problems(self, number_of_problems):
l = []
for i in range(number_of_problems):
x,y,z = random.randint(0,9), random.randint(0,9), random.randint(0,9)
op1, op2 = random.choice('-+*'), random.choice('-+*')
to_eval = f'{x}{op1}({y}{op2}{z})'
l.append(ArithmeticDoc(question_text=to_eval, answer_text=str(eval(to_eval))))
return l
def get_file_download_info(self):
return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment