"include/vscode:/vscode.git/clone" did not exist on "334361cbde76a2566fb215a64a6652205b0d2336"
Unverified commit 693c19e2, authored by Leo Gao and committed by GitHub
Browse files

Merge pull request #87 from VitamintK/master

first attempt at adding arithmetic evaluations
parents 38892f0a baadb2da
...@@ -12,6 +12,7 @@ from . import openbookqa ...@@ -12,6 +12,7 @@ from . import openbookqa
from . import squad from . import squad
from . import naturalqs from . import naturalqs
from . import sat from . import sat
from . import arithmetic
TASK_REGISTRY = { TASK_REGISTRY = {
# GLUE # GLUE
...@@ -49,6 +50,18 @@ TASK_REGISTRY = { ...@@ -49,6 +50,18 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1, "anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2, "anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3, "anli_r3": anli.ANLIRound3,
# arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus,
"arithmetic_3da": arithmetic.Arithmetic3DPlus,
"arithmetic_3ds": arithmetic.Arithmetic3DMinus,
"arithmetic_4da": arithmetic.Arithmetic4DPlus,
"arithmetic_4ds": arithmetic.Arithmetic4DMinus,
"arithmetic_5da": arithmetic.Arithmetic5DPlus,
"arithmetic_5ds": arithmetic.Arithmetic5DMinus,
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
} }
......
import abc
import json
import os
from collections import namedtuple
from lm_eval.base import Dataset, mean, rf
from best_download import download_file
# One arithmetic example: the prompt text (`context`) and the expected
# continuation (`completion`), as read from a line of the JSONL data file.
ArithmeticDoc = namedtuple('ArithmeticDoc', ('context', 'completion'))
class Arithmetic(Dataset):
    """Base task for arithmetic problems stored as JSON-lines files.

    Concrete subclasses name one data file (and its checksum) through
    ``get_file_download_info``. This class fetches that file, parses each
    line into an ``ArithmeticDoc`` and scores a model on whether the
    expected completion is its prediction.
    """

    # Local directory in which the downloaded ``.jsonl`` files are stored.
    directory = 'data/arithmetic/'

    def __init__(self):
        # NOTE(review): presumably Dataset.__init__ calls download() so the
        # data file exists before set_docs() reads it -- confirm in
        # lm_eval.base.
        super().__init__()
        self.set_docs()

    def download(self):
        """Download this task's data file into ``self.directory``."""
        file_name, checksum = self.get_file_download_info()
        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
        # exist_ok avoids the check-then-create race when several tasks
        # (or processes) set up the same directory at once.
        os.makedirs(self.directory, exist_ok=True)
        download_file(url, os.path.join(self.directory, file_name), checksum)

    @abc.abstractmethod
    def get_file_download_info(self):
        """returns a tuple of (file_name, checksum)"""
        # NOTE(review): the abstractmethod marker is only enforced at
        # instantiation if Dataset uses an ABCMeta metaclass -- confirm.
        pass

    def set_docs(self):
        """Parse the task's JSONL file into ``self._docs``.

        Raises:
            OSError: if the data file cannot be opened.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            KeyError: if a record lacks ``context`` or ``completion``.
        """
        file_name, _ = self.get_file_download_info()
        path = os.path.join(self.directory, file_name)
        # The context manager closes the handle once parsing is done; the
        # previous bare open() left it to the garbage collector.
        with open(path, 'r', encoding='utf-8') as jsons:
            self._docs = [
                self.load_doc(json.loads(line))
                for line in jsons
                if line.strip()  # tolerate blank lines in the file
            ]

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return every document loaded from the data file."""
        return self._docs

    def validation_docs(self):
        """Return the first 100 loaded documents.

        Note that these are a prefix of ``training_docs()``, not a
        disjoint split.
        """
        return self._docs[:100]

    def test_docs(self):
        # NOTE(review): this returns the NotImplemented singleton rather
        # than raising NotImplementedError. Kept as-is so existing callers
        # see the same value; has_test_docs() reports False.
        return NotImplemented

    def doc_to_text(self, doc):
        """Return the prompt text of ``doc``."""
        return doc.context

    def doc_to_target(self, doc):
        """Return the expected completion of ``doc``."""
        return doc.completion

    def load_doc(self, doc_json):
        """Build an ``ArithmeticDoc`` from one decoded JSON record."""
        return ArithmeticDoc(context=doc_json['context'], completion=doc_json['completion'])

    def construct_requests(self, doc, ctx):
        """Request the log-likelihood of the completion given ``ctx``.

        Only the second element of the request pair (``is_prediction``)
        is returned; the log-likelihood itself is not used for scoring.
        """
        _, is_prediction = rf.loglikelihood(ctx, doc.completion)
        return is_prediction

    def process_results(self, doc, results):
        """Turn request results into the per-document ``acc`` metric.

        ``construct_requests`` returns a single request, so ``results`` is
        expected to carry a single value. The previous two-name unpacking
        (``ll, is_prediction = results``) would fail on that shape. Taking
        the last element accepts both a lone ``is_prediction`` and a
        ``(ll, is_prediction)`` pair.
        """
        *_, is_prediction = results
        return {
            "acc": is_prediction
        }

    def aggregation(self):
        """Map each metric name to the function that aggregates it."""
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        """Map each metric name to whether larger values are better."""
        return {
            "acc": True
        }
class Arithmetic2DPlus(Arithmetic):
    """Arithmetic task backed by ``two_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'two_digit_addition.jsonl'
        checksum = '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
        return file_name, checksum
class Arithmetic2DMinus(Arithmetic):
    """Arithmetic task backed by ``two_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'two_digit_subtraction.jsonl'
        checksum = 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
        return file_name, checksum
class Arithmetic3DPlus(Arithmetic):
    """Arithmetic task backed by ``three_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'three_digit_addition.jsonl'
        checksum = '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
        return file_name, checksum
class Arithmetic3DMinus(Arithmetic):
    """Arithmetic task backed by ``three_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'three_digit_subtraction.jsonl'
        checksum = '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
        return file_name, checksum
class Arithmetic4DPlus(Arithmetic):
    """Arithmetic task backed by ``four_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'four_digit_addition.jsonl'
        checksum = '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
        return file_name, checksum
class Arithmetic4DMinus(Arithmetic):
    """Arithmetic task backed by ``four_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'four_digit_subtraction.jsonl'
        checksum = '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
        return file_name, checksum
class Arithmetic5DPlus(Arithmetic):
    """Arithmetic task backed by ``five_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'five_digit_addition.jsonl'
        checksum = '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
        return file_name, checksum
class Arithmetic5DMinus(Arithmetic):
    """Arithmetic task backed by ``five_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'five_digit_subtraction.jsonl'
        checksum = '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
        return file_name, checksum
class Arithmetic2DMultiplication(Arithmetic):
    """Arithmetic task backed by ``two_digit_multiplication.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'two_digit_multiplication.jsonl'
        checksum = '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
        return file_name, checksum
class Arithmetic1DComposite(Arithmetic):
    """Arithmetic task backed by ``single_digit_three_ops.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data."""
        file_name = 'single_digit_three_ops.jsonl'
        checksum = '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'
        return file_name, checksum
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment