# arithmetic.py: arithmetic evaluation tasks built on the openai/gpt-3 data files.
import abc
import json
import os
from collections import namedtuple

from best_download import download_file

from lm_eval.base import Task, rf
from lm_eval.metrics import mean

# One parsed example: ``context`` is the prompt text given to the model and
# ``completion`` is the continuation it is scored against.
ArithmeticDoc = namedtuple('ArithmeticDoc', 'context completion')


class Arithmetic(Task):
    """Base class for the arithmetic evaluation tasks.

    A subclass picks its data file by implementing ``get_file_download_info``.
    ``download`` fetches that file from the openai/gpt-3 GitHub repository into
    ``directory`` and parses it as JSON lines. Every parsed example is exposed
    as a validation document; there are no training or test documents.
    """

    # Local directory the .jsonl data files are downloaded into.
    directory = 'data/arithmetic/'

    def __init__(self):
        super().__init__()

    def _data_path(self, file_name):
        """Return the local path of ``file_name`` inside ``directory``."""
        # os.path.join also works if a subclass drops the trailing slash.
        return os.path.join(self.directory, file_name)

    def download(self):
        """Download this task's data file, then load its documents.

        The expected checksum from ``get_file_download_info`` is passed
        through to ``download_file``.
        """
        file_name, checksum = self.get_file_download_info()
        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
        # exist_ok=True avoids the race between an exists() check and creation.
        os.makedirs(self.directory, exist_ok=True)
        download_file(url, self._data_path(file_name), checksum)
        self.set_docs()

    @abc.abstractmethod
    def get_file_download_info(self):
        """returns a tuple of (file_name, checksum)"""

    def set_docs(self):
        """Parse the downloaded JSON-lines file into ``self._docs``.

        Each non-blank line is decoded with ``json.loads`` and converted by
        ``load_doc``.
        """
        file_name, _ = self.get_file_download_info()
        # The context manager guarantees the handle is closed (the original leaked it).
        with open(self._data_path(file_name), 'r', encoding='utf-8') as jsons:
            self._docs = [
                self.load_doc(json.loads(line))
                for line in jsons
                if line.strip()  # a blank line would make json.loads raise
            ]

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # has_training_docs() is False. Callers that check it first never get
        # here. The NotImplemented sentinel is kept for compatibility.
        return NotImplemented

    def validation_docs(self):
        """Return the documents loaded by ``set_docs`` (requires ``download``)."""
        return self._docs

    def test_docs(self):
        # has_test_docs() is False; sentinel kept for compatibility.
        return NotImplemented

    def doc_to_text(self, doc):
        return doc.context

    def doc_to_target(self, doc):
        return doc.completion

    def load_doc(self, doc_json):
        """Build an ``ArithmeticDoc`` from one decoded JSON line.

        The context is stripped of surrounding whitespace. Each ``'\\n\\n'``
        pair becomes ``'\\n'``, and ``Q:``/``A:`` are spelled out as
        ``Question:``/``Answer:``. The completion is kept verbatim.
        """
        context = (
            doc_json['context'].strip()
            .replace('\n\n', '\n')
            .replace('Q:', 'Question:')
            .replace('A:', 'Answer:')
        )
        return ArithmeticDoc(context=context, completion=doc_json['completion'])

    def construct_requests(self, doc, ctx):
        """Request the model's log-likelihood of ``doc.completion`` after ``ctx``.

        Only the second value of the request is used for scoring. It is
        presumably the "completion is the model's greedy prediction" flag;
        confirm against ``lm_eval``'s ``rf.loglikelihood``.
        """
        _, is_prediction = rf.loglikelihood(ctx, doc.completion)
        return is_prediction

    def process_results(self, doc, results):
        """Map the single request result to the ``acc`` metric."""
        is_prediction, = results
        return {
            "acc": is_prediction
        }

    def aggregation(self):
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        return {
            "acc": True
        }


class Arithmetic2DPlus(Arithmetic):
    """Arithmetic task backed by ``two_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'two_digit_addition.jsonl'
        checksum = '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
        return file_name, checksum


class Arithmetic2DMinus(Arithmetic):
    """Arithmetic task backed by ``two_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'two_digit_subtraction.jsonl'
        checksum = 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
        return file_name, checksum


class Arithmetic3DPlus(Arithmetic):
    """Arithmetic task backed by ``three_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'three_digit_addition.jsonl'
        checksum = '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
        return file_name, checksum


class Arithmetic3DMinus(Arithmetic):
    """Arithmetic task backed by ``three_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'three_digit_subtraction.jsonl'
        checksum = '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
        return file_name, checksum


class Arithmetic4DPlus(Arithmetic):
    """Arithmetic task backed by ``four_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'four_digit_addition.jsonl'
        checksum = '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
        return file_name, checksum


class Arithmetic4DMinus(Arithmetic):
    """Arithmetic task backed by ``four_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'four_digit_subtraction.jsonl'
        checksum = '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
        return file_name, checksum


class Arithmetic5DPlus(Arithmetic):
    """Arithmetic task backed by ``five_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'five_digit_addition.jsonl'
        checksum = '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
        return file_name, checksum


class Arithmetic5DMinus(Arithmetic):
    """Arithmetic task backed by ``five_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'five_digit_subtraction.jsonl'
        checksum = '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
        return file_name, checksum


class Arithmetic2DMultiplication(Arithmetic):
    """Arithmetic task backed by ``two_digit_multiplication.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        file_name = 'two_digit_multiplication.jsonl'
        checksum = '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
        return file_name, checksum


class Arithmetic1DComposite(Arithmetic):
    """Arithmetic task backed by ``single_digit_three_ops.jsonl``."""

    # Indentation normalized to 4 spaces (the method was indented by 3),
    # matching every sibling task class and PEP 8.
    def get_file_download_info(self):
        """Return the data file name and the checksum it must match."""
        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'