arithmetic.py 4 KB
Newer Older
1
2
import abc
import json
3
import os
4
from collections import namedtuple
&'s avatar
& committed
5
6
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
7
from best_download import download_file
8

9
# One evaluation example: the prompt text shown to the model (``context``)
# and the expected answer text scored against it (``completion``).
ArithmeticDoc = namedtuple('ArithmeticDoc', 'context completion')


class Arithmetic(Task):
    """Base class for the GPT-3 arithmetic evaluation tasks.

    Each concrete subclass only implements :meth:`get_file_download_info`,
    naming one JSONL data file from the ``openai/gpt-3`` GitHub repository
    together with its checksum.  The file is downloaded into
    :attr:`directory`, every line is parsed into an ``ArithmeticDoc``, and
    all examples are served as validation docs (there are no training or
    test splits).  Each doc produces one log-likelihood request on its
    expected completion; the value kept from that request is reported as
    ``acc`` and averaged over the dataset.
    """

    # Local directory the data files are downloaded into and read from.
    directory = 'data/arithmetic/'

    def __init__(self):
        # NOTE(review): self._docs is only populated by set_docs(), which
        # download() calls; presumably Task.__init__ invokes download() —
        # confirm against lm_eval.base.Task.
        super().__init__()

    def download(self):
        """Download this task's data file and load its docs.

        The expected checksum from :meth:`get_file_download_info` is passed
        through to ``download_file`` along with the source URL and the
        local destination path.
        """
        file_name, checksum = self.get_file_download_info()
        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
        # exist_ok=True avoids a separate exists() check that could race
        # with another process creating the same directory.
        os.makedirs(self.directory, exist_ok=True)
        download_file(url, os.path.join(self.directory, file_name), checksum)
        self.set_docs()

    @abc.abstractmethod
    def get_file_download_info(self):
        """Return a ``(file_name, checksum)`` tuple for this task's data file.

        ``file_name`` is both the name under the ``openai/gpt-3`` ``data/``
        directory and the local file name under :attr:`directory`.
        """
        pass

    def set_docs(self):
        """Parse the downloaded JSONL file into ``self._docs``.

        Every non-blank line is decoded with ``json.loads`` and converted by
        :meth:`load_doc`; blank lines are skipped because ``json.loads``
        rejects them.
        """
        file_name, _ = self.get_file_download_info()
        path = os.path.join(self.directory, file_name)
        # The context manager guarantees the file handle is closed; JSON text
        # is UTF-8, so don't depend on the platform's default encoding.
        with open(path, 'r', encoding='utf-8') as jsons:
            self._docs = [
                self.load_doc(json.loads(line))
                for line in jsons
                if line.strip()
            ]

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # There is no training split; NotImplemented is kept as the return
        # value so existing callers see unchanged behavior.
        return NotImplemented

    def validation_docs(self):
        """Return the list of ``ArithmeticDoc`` built by :meth:`set_docs`."""
        return self._docs

    def test_docs(self):
        # There is no test split; NotImplemented is kept as the return value
        # so existing callers see unchanged behavior.
        return NotImplemented

    def doc_to_text(self, doc):
        """Return the prompt (the normalized context) for *doc*."""
        return doc.context

    def doc_to_target(self, doc):
        """Return the expected completion for *doc*."""
        return doc.completion

    def load_doc(self, doc_json):
        """Convert one parsed JSONL record into an ``ArithmeticDoc``.

        The context is stripped of surrounding whitespace, every pair of
        consecutive newlines is replaced by a single newline, and the
        ``Q:``/``A:`` markers are spelled out as ``Question:``/``Answer:``.
        The completion is kept verbatim.
        """
        context = (
            doc_json['context'].strip()
            .replace('\n\n', '\n')
            .replace('Q:', 'Question:')
            .replace('A:', 'Answer:')
        )
        return ArithmeticDoc(context=context, completion=doc_json['completion'])

    def construct_requests(self, doc, ctx):
        """Build the single log-likelihood request for *doc*.

        ``rf.loglikelihood(ctx, completion)`` returns a pair; only its second
        element is kept, and :meth:`process_results` reports it as ``acc``.
        NOTE(review): presumably that element flags whether the completion is
        the model's greedy continuation — confirm against lm_eval.base.rf.
        """
        _, is_prediction = rf.loglikelihood(ctx, doc.completion)
        return is_prediction

    def process_results(self, doc, results):
        """Report the single request result for *doc* as the ``acc`` metric."""
        is_prediction, = results
        return {
            "acc": is_prediction
        }

    def aggregation(self):
        # Per-doc "acc" values are averaged across the dataset.
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        return {
            "acc": True
        }


class Arithmetic2DPlus(Arithmetic):
    """Arithmetic task whose examples come from ``two_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'two_digit_addition.jsonl'
        checksum = '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
        return file_name, checksum

class Arithmetic2DMinus(Arithmetic):
    """Arithmetic task whose examples come from ``two_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'two_digit_subtraction.jsonl'
        checksum = 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
        return file_name, checksum

class Arithmetic3DPlus(Arithmetic):
    """Arithmetic task whose examples come from ``three_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'three_digit_addition.jsonl'
        checksum = '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
        return file_name, checksum

class Arithmetic3DMinus(Arithmetic):
    """Arithmetic task whose examples come from ``three_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'three_digit_subtraction.jsonl'
        checksum = '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
        return file_name, checksum

class Arithmetic4DPlus(Arithmetic):
    """Arithmetic task whose examples come from ``four_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'four_digit_addition.jsonl'
        checksum = '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
        return file_name, checksum

class Arithmetic4DMinus(Arithmetic):
    """Arithmetic task whose examples come from ``four_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'four_digit_subtraction.jsonl'
        checksum = '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
        return file_name, checksum

class Arithmetic5DPlus(Arithmetic):
    """Arithmetic task whose examples come from ``five_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'five_digit_addition.jsonl'
        checksum = '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
        return file_name, checksum

class Arithmetic5DMinus(Arithmetic):
    """Arithmetic task whose examples come from ``five_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'five_digit_subtraction.jsonl'
        checksum = '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
        return file_name, checksum

class Arithmetic2DMultiplication(Arithmetic):
    """Arithmetic task whose examples come from ``two_digit_multiplication.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        file_name = 'two_digit_multiplication.jsonl'
        checksum = '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
        return file_name, checksum

class Arithmetic1DComposite(Arithmetic):
    """Arithmetic task whose examples come from ``single_digit_three_ops.jsonl``."""

    def get_file_download_info(self):
        """Return the ``(file_name, checksum)`` pair for this task's data file."""
        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'