# arithmetic.py
1
2
import abc
import json
3
import os
4
5
from collections import namedtuple
from lm_eval.base import Dataset, mean, rf
6
from best_download import download_file
7

# One arithmetic problem: `context` is the prompt text (see doc_to_text) and
# `completion` is the expected continuation (see doc_to_target).
ArithmeticDoc = namedtuple("ArithmeticDoc", "context completion")

class Arithmetic(Dataset):
    """Base class for the GPT-3 arithmetic tasks.

    Each concrete subclass names one JSONL file from the ``data/`` folder of
    the openai/gpt-3 GitHub repository via ``get_file_download_info``. Every
    line of that file is parsed as JSON and turned into an ``ArithmeticDoc``
    built from its ``context`` and ``completion`` fields.
    """

    # Local directory the JSONL files are downloaded into and read from.
    directory = 'data/arithmetic/'

    def __init__(self):
        # NOTE(review): set_docs reads the file from disk, so this assumes
        # Dataset.__init__ has already fetched it (presumably by calling
        # self.download()) — confirm in lm_eval.base.
        super().__init__()
        self.set_docs()

    def download(self):
        """Download this task's JSONL file into ``self.directory``.

        The file name and checksum come from ``get_file_download_info``. The
        checksum is handed to ``download_file`` (presumably to verify the
        download — confirm in best_download).
        """
        file_name, checksum = self.get_file_download_info()
        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
        # exist_ok=True replaces the racy exists()-then-makedirs() check.
        os.makedirs(self.directory, exist_ok=True)
        download_file(url, os.path.join(self.directory, file_name), checksum)

    @abc.abstractmethod
    def get_file_download_info(self):
        """Return a ``(file_name, checksum)`` tuple for this task's data file.

        ``file_name`` is resolved both against the gpt-3 repo's ``data/``
        folder (for download) and against ``self.directory`` (for loading).
        """
        # NOTE(review): @abc.abstractmethod is only enforced if Dataset uses
        # abc.ABCMeta as its metaclass — confirm in lm_eval.base.
        pass

    def set_docs(self):
        """Parse the downloaded JSONL file into ``self._docs``.

        Each non-blank line is decoded with ``json.loads`` and converted with
        ``load_doc``. Blank lines are skipped because ``json.loads`` rejects
        them.
        """
        file_name, _ = self.get_file_download_info()
        path = os.path.join(self.directory, file_name)
        # The context manager closes the handle once parsing is finished.
        with open(path, 'r', encoding='utf-8') as jsons:
            self._docs = [
                self.load_doc(json.loads(line))
                for line in jsons
                if line.strip()
            ]

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return every loaded document."""
        return self._docs

    def validation_docs(self):
        """Return the first 100 loaded documents."""
        # NOTE(review): this is a prefix of training_docs(), so the two splits
        # overlap. If few-shot examples are sampled from training_docs(), an
        # item can be shown as its own demonstration. Confirm this is intended.
        return self._docs[:100]

    def test_docs(self):
        # No test split is exposed (has_test_docs() returns False).
        return NotImplemented

    def doc_to_text(self, doc):
        """Return the prompt text for ``doc``."""
        return doc.context

    def doc_to_target(self, doc):
        """Return the expected continuation for ``doc``."""
        return doc.completion

    def load_doc(self, doc_json):
        """Build an ``ArithmeticDoc`` from one decoded JSONL record."""
        return ArithmeticDoc(context=doc_json['context'], completion=doc_json['completion'])

    def construct_requests(self, doc, ctx):
        """Build the single request scored for ``doc``.

        ``rf.loglikelihood`` unpacks into two request handles. Only the second
        one (named ``is_prediction`` here; presumably whether the completion is
        the model's greedy continuation) is issued.
        """
        _, is_prediction = rf.loglikelihood(ctx, doc.completion)
        return is_prediction

    def process_results(self, doc, results):
        """Map the request result for ``doc`` to the ``acc`` metric."""
        # construct_requests issues exactly one request, so there is exactly one
        # result. Unpacking two values here (``ll, is_prediction = results``)
        # raised ValueError.
        is_prediction, = results
        return {
            "acc": is_prediction
        }

    def aggregation(self):
        """Per-document ``acc`` values are averaged."""
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        return {
            "acc": True
        }


class Arithmetic2DPlus(Arithmetic):
    """Two-digit addition problems (``two_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'

class Arithmetic2DMinus(Arithmetic):
    """Two-digit subtraction problems (``two_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'

class Arithmetic3DPlus(Arithmetic):
    """Three-digit addition problems (``three_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'

class Arithmetic3DMinus(Arithmetic):
    """Three-digit subtraction problems (``three_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'

class Arithmetic4DPlus(Arithmetic):
    """Four-digit addition problems (``four_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'

class Arithmetic4DMinus(Arithmetic):
    """Four-digit subtraction problems (``four_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'

class Arithmetic5DPlus(Arithmetic):
    """Five-digit addition problems (``five_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'

class Arithmetic5DMinus(Arithmetic):
    """Five-digit subtraction problems (``five_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'

class Arithmetic2DMultiplication(Arithmetic):
    """Two-digit multiplication problems (``two_digit_multiplication.jsonl``)."""

    def get_file_download_info(self):
        return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'

class Arithmetic1DComposite(Arithmetic):
    """Single-digit composite problems (``single_digit_three_ops.jsonl``)."""

    def get_file_download_info(self):
        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'