# arithmetic.py (5.36 KB) -- repository web-view header residue; not part of the module.
1
2
3
4
5
6
7
8
9
"""
Language Models are Few-Shot Learners
https://arxiv.org/pdf/2005.14165.pdf

A small battery of 10 tests that involve asking language models a simple arithmetic
problem in natural language.

Homepage: https://github.com/openai/gpt-3/tree/master/data
"""
10
11
import abc
import json
12
import os
13
from collections import namedtuple
# (repository web-view residue: commit author marker; not part of the module)
14
15
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
16
from best_download import download_file
17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

# BibTeX entry for the paper named in the module docstring
# ("Language Models are Few-Shot Learners", NeurIPS 2020).
_CITATION = """
@inproceedings{NEURIPS2020_1457c0d6,
    author = {Brown, Tom and Mann, Benjamin and Ryder, Nick and Subbiah, Melanie and Kaplan, Jared D and Dhariwal, Prafulla and Neelakantan, Arvind and Shyam, Pranav and Sastry, Girish and Askell, Amanda and Agarwal, Sandhini and Herbert-Voss, Ariel and Krueger, Gretchen and Henighan, Tom and Child, Rewon and Ramesh, Aditya and Ziegler, Daniel and Wu, Jeffrey and Winter, Clemens and Hesse, Chris and Chen, Mark and Sigler, Eric and Litwin, Mateusz and Gray, Scott and Chess, Benjamin and Clark, Jack and Berner, Christopher and McCandlish, Sam and Radford, Alec and Sutskever, Ilya and Amodei, Dario},
    booktitle = {Advances in Neural Information Processing Systems},
    editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
    pages = {1877--1901},
    publisher = {Curran Associates, Inc.},
    title = {Language Models are Few-Shot Learners},
    url = {https://proceedings.neurips.cc/paper/2020/file/1457c0d6bfcb4967418bfb8ac142f64a-Paper.pdf},
    volume = {33},
    year = {2020}
}
"""


34
# One parsed example from a task data file:
#   context    -- the prompt text shown to the model
#   completion -- the reference continuation that is scored
ArithmeticDoc = namedtuple(
    'ArithmeticDoc',
    ('context', 'completion'),
)
35

36

37
class Arithmetic(Task):
    """Shared implementation for the arithmetic tasks in this module.

    A subclass only names its data file (and that file's checksum) through
    ``get_file_download_info``. This class downloads the file, parses each
    line into an ``ArithmeticDoc`` and exposes the result as the validation
    split. The single reported metric is ``acc``.
    """

    VERSION = 0
    # Local directory the data files are downloaded into and read from.
    directory = 'data/arithmetic/'

    def __init__(self):
        super().__init__()

    def download(self):
        """Download this task's data file into ``directory`` and load it."""
        file_name, checksum = self.get_file_download_info()
        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
        # exist_ok=True replaces the exists()/makedirs() pair, which could
        # raise if another task or process created the directory in between.
        os.makedirs(self.directory, exist_ok=True)
        download_file(
            url,
            local_file=os.path.join(self.directory, file_name),
            expected_checksum=checksum,
        )
        self.set_docs()

    @abc.abstractmethod
    def get_file_download_info(self):
        """returns a tuple of (file_name, checksum)"""
        pass

    def set_docs(self):
        """Parse the downloaded JSON-lines file into ``self._docs``.

        Each non-blank line is decoded as JSON and converted with
        ``load_doc``.
        """
        file_name, _ = self.get_file_download_info()
        # The context manager closes the handle; the previous bare open()
        # left it open for the lifetime of the process.
        with open(os.path.join(self.directory, file_name), 'r', encoding='utf-8') as jsons:
            self._docs = [
                self.load_doc(json.loads(line))
                for line in jsons
                if line.strip()  # tolerate blank lines instead of failing in json.loads
            ]

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # No training split exists (see has_training_docs). The sentinel return
        # value is kept unchanged for backward compatibility.
        return NotImplemented

    def validation_docs(self):
        """Return the documents loaded by ``set_docs``."""
        return self._docs

    def test_docs(self):
        # No test split exists (see has_test_docs). The sentinel return value
        # is kept unchanged for backward compatibility.
        return NotImplemented

    def doc_to_text(self, doc):
        """Return the prompt text of ``doc``."""
        return doc.context

    def doc_to_target(self, doc):
        """Return the reference completion of ``doc``."""
        return doc.completion

    def load_doc(self, doc_json):
        """Build an ``ArithmeticDoc`` from one decoded JSON record.

        The context is stripped, double newlines are collapsed to single
        ones, and the ``Q:`` / ``A:`` markers are spelled out as
        ``Question:`` / ``Answer:``. The completion is passed through
        unchanged.
        """
        context = (
            doc_json['context']
            .strip()
            .replace('\n\n', '\n')
            .replace('Q:', 'Question:')
            .replace('A:', 'Answer:')
        )
        return ArithmeticDoc(context=context, completion=doc_json['completion'])

    def construct_requests(self, doc, ctx):
        """Request the log-likelihood of ``doc.completion`` given ``ctx``.

        Only the second element of the pair returned by ``rf.loglikelihood``
        is kept.
        """
        # NOTE(review): presumably the second element flags whether the
        # completion is the model's own prediction -- confirm against
        # rf.loglikelihood.
        _, is_prediction = rf.loglikelihood(ctx, doc.completion)
        return is_prediction

    def process_results(self, doc, results):
        """Map the single request result onto the ``acc`` metric."""
        (is_prediction,) = results
        return {
            "acc": is_prediction
        }

    def aggregation(self):
        """Return the aggregation function for each metric."""
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        """Return, per metric, whether larger values are better."""
        return {
            "acc": True
        }


class Arithmetic2DPlus(Arithmetic):
    """Arithmetic task backed by ``two_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'two_digit_addition.jsonl'
        checksum = '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
        return file_name, checksum
116
117

class Arithmetic2DMinus(Arithmetic):
    """Arithmetic task backed by ``two_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'two_digit_subtraction.jsonl'
        checksum = 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
        return file_name, checksum
120
121

class Arithmetic3DPlus(Arithmetic):
    """Arithmetic task backed by ``three_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'three_digit_addition.jsonl'
        checksum = '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
        return file_name, checksum
124
125

class Arithmetic3DMinus(Arithmetic):
    """Arithmetic task backed by ``three_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'three_digit_subtraction.jsonl'
        checksum = '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
        return file_name, checksum
128
129

class Arithmetic4DPlus(Arithmetic):
    """Arithmetic task backed by ``four_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'four_digit_addition.jsonl'
        checksum = '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
        return file_name, checksum
132
133

class Arithmetic4DMinus(Arithmetic):
    """Arithmetic task backed by ``four_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'four_digit_subtraction.jsonl'
        checksum = '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
        return file_name, checksum
136
137

class Arithmetic5DPlus(Arithmetic):
    """Arithmetic task backed by ``five_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'five_digit_addition.jsonl'
        checksum = '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
        return file_name, checksum
140
141

class Arithmetic5DMinus(Arithmetic):
    """Arithmetic task backed by ``five_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'five_digit_subtraction.jsonl'
        checksum = '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
        return file_name, checksum
144
145

class Arithmetic2DMultiplication(Arithmetic):
    """Arithmetic task backed by ``two_digit_multiplication.jsonl``."""

    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        file_name = 'two_digit_multiplication.jsonl'
        checksum = '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
        return file_name, checksum
148
149

class Arithmetic1DComposite(Arithmetic):
    """Arithmetic task backed by ``single_digit_three_ops.jsonl``."""

    # The method was indented with three spaces; it now uses the four-space
    # indent shared by every other task class in this module.
    def get_file_download_info(self):
        """Return the data file name and its expected checksum."""
        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'