# arithmetic.py: GPT-3 arithmetic evaluation tasks (data from openai/gpt-3).
import abc
import json
import os
from collections import namedtuple

from best_download import download_file

from lm_eval.base import Task, rf
from lm_eval.metrics import mean

# One arithmetic problem: ``context`` is the prompt shown to the model and
# ``completion`` is the reference answer it is scored against. The field names
# mirror the JSONL keys read by Arithmetic.load_doc.
ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])


class Arithmetic(Task):
    """Base task for the GPT-3 arithmetic evaluation sets.

    Each concrete subclass names one JSONL data file (and its checksum) via
    :meth:`get_file_download_info`. :meth:`download` fetches that file from
    ``base_url`` into ``directory`` and parses every line into an
    :class:`ArithmeticDoc`. The parsed set is exposed as validation docs only.
    """

    VERSION = 0
    # Local cache directory for the downloaded JSONL files. The trailing slash
    # is kept so existing code that concatenates onto it keeps working.
    directory = 'data/arithmetic/'
    # Remote location of the data files; file_name is appended to it.
    base_url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/'

    def download(self):
        """Fetch this task's data file (checksum-verified) and load its docs."""
        file_name, checksum = self.get_file_download_info()
        url = self.base_url + file_name
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(self.directory, exist_ok=True)
        download_file(url, os.path.join(self.directory, file_name), checksum)
        self.set_docs()

    @abc.abstractmethod
    def get_file_download_info(self):
        """Return a ``(file_name, checksum)`` tuple for this task's data file.

        ``file_name`` is appended to ``base_url`` and ``directory``;
        ``checksum`` is passed unchanged to ``download_file``.

        Raises:
            NotImplementedError: if a subclass does not override this method.
        """
        # NOTE(review): abc.abstractmethod is only enforced at instantiation if
        # Task's metaclass is abc.ABCMeta -- confirm in lm_eval.base.
        raise NotImplementedError

    def set_docs(self):
        """Parse the downloaded JSONL file into ``self._docs``.

        Blank lines are skipped. Every other line must be a JSON object with
        ``context`` and ``completion`` keys (see :meth:`load_doc`).
        """
        file_name, _ = self.get_file_download_info()
        path = os.path.join(self.directory, file_name)
        # The context manager guarantees the handle is closed after parsing.
        with open(path, 'r', encoding='utf-8') as jsonl:
            self._docs = [
                self.load_doc(json.loads(line)) for line in jsonl if line.strip()
            ]

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # No training split; has_training_docs() reports False.
        return NotImplemented

    def validation_docs(self):
        """Return the docs loaded by :meth:`set_docs`."""
        return self._docs

    def test_docs(self):
        # No test split; has_test_docs() reports False.
        return NotImplemented

    def doc_to_text(self, doc):
        return doc.context

    def doc_to_target(self, doc):
        return doc.completion

    def load_doc(self, doc_json):
        """Build an :class:`ArithmeticDoc` from one parsed JSONL record.

        The context is stripped, and each ``\\n\\n`` pair is replaced with a
        single newline. The ``Q:``/``A:`` markers are expanded to
        ``Question:``/``Answer:``. The completion is kept verbatim.
        """
        context = (
            doc_json['context']
            .strip()
            .replace('\n\n', '\n')
            .replace('Q:', 'Question:')
            .replace('A:', 'Answer:')
        )
        return ArithmeticDoc(context=context, completion=doc_json['completion'])

    def construct_requests(self, doc, ctx):
        """Request the log-likelihood of the reference completion after ``ctx``.

        Only the second element of the ``rf.loglikelihood`` result (bound to
        ``is_prediction``) is used for scoring; the log-likelihood itself is
        discarded.
        """
        _, is_prediction = rf.loglikelihood(ctx, doc.completion)
        return is_prediction

    def process_results(self, doc, results):
        """Map the single request result to the ``acc`` metric."""
        # Unpacking a 1-tuple also asserts exactly one result came back.
        is_prediction, = results
        return {
            "acc": is_prediction
        }

    def aggregation(self):
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        return {
            "acc": True
        }


class Arithmetic2DPlus(Arithmetic):
    """Two-digit addition (``two_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'two_digit_addition.jsonl', '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'


class Arithmetic2DMinus(Arithmetic):
    """Two-digit subtraction (``two_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'two_digit_subtraction.jsonl', 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'


class Arithmetic3DPlus(Arithmetic):
    """Three-digit addition (``three_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'three_digit_addition.jsonl', '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'


class Arithmetic3DMinus(Arithmetic):
    """Three-digit subtraction (``three_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'three_digit_subtraction.jsonl', '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'


class Arithmetic4DPlus(Arithmetic):
    """Four-digit addition (``four_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'four_digit_addition.jsonl', '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'


class Arithmetic4DMinus(Arithmetic):
    """Four-digit subtraction (``four_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'four_digit_subtraction.jsonl', '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'


class Arithmetic5DPlus(Arithmetic):
    """Five-digit addition (``five_digit_addition.jsonl``)."""

    def get_file_download_info(self):
        return 'five_digit_addition.jsonl', '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'


class Arithmetic5DMinus(Arithmetic):
    """Five-digit subtraction (``five_digit_subtraction.jsonl``)."""

    def get_file_download_info(self):
        return 'five_digit_subtraction.jsonl', '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'


class Arithmetic2DMultiplication(Arithmetic):
    """Two-digit multiplication (``two_digit_multiplication.jsonl``)."""

    def get_file_download_info(self):
        return 'two_digit_multiplication.jsonl', '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'


class Arithmetic1DComposite(Arithmetic):
    """Single-digit composite problems (``single_digit_three_ops.jsonl``)."""

    def get_file_download_info(self):
        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'