arithmetic.py 5.31 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""
Language Models are Few-Shot Learners
https://arxiv.org/pdf/2005.14165.pdf

A small battery of 10 tests that involve asking language models a simple arithmetic
problem in natural language.

Homepage: https://github.com/openai/gpt-3/tree/master/data

@inproceedings{NEURIPS2020_1457c0d6,
 author = {Brown, Tom and Mann, Benjamin and Ryder, Nick and Subbiah, Melanie and Kaplan, Jared D and Dhariwal, Prafulla and Neelakantan, Arvind and Shyam, Pranav and Sastry, Girish and Askell, Amanda and Agarwal, Sandhini and Herbert-Voss, Ariel and Krueger, Gretchen and Henighan, Tom and Child, Rewon and Ramesh, Aditya and Ziegler, Daniel and Wu, Jeffrey and Winter, Clemens and Hesse, Chris and Chen, Mark and Sigler, Eric and Litwin, Mateusz and Gray, Scott and Chess, Benjamin and Clark, Jack and Berner, Christopher and McCandlish, Sam and Radford, Alec and Sutskever, Ilya and Amodei, Dario},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
 pages = {1877--1901},
 publisher = {Curran Associates, Inc.},
 title = {Language Models are Few-Shot Learners},
 url = {https://proceedings.neurips.cc/paper/2020/file/1457c0d6bfcb4967418bfb8ac142f64a-Paper.pdf},
 volume = {33},
 year = {2020}
}
"""
22
23
import abc
import json
24
import os
25
from collections import namedtuple
&'s avatar
& committed
26
27
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
28
from best_download import download_file
29

30
# Record for a single example: `context` is what doc_to_text returns and
# `completion` is what doc_to_target returns.
ArithmeticDoc = namedtuple(
    'ArithmeticDoc',
    ('context', 'completion'),
)
31

32

33
class Arithmetic(Task):
    """Base task for the arithmetic data files of the openai/gpt-3 repository.

    Concrete subclasses pick a data file by implementing
    ``get_file_download_info``. Only validation documents are exposed;
    there are no training or test splits.
    """

    VERSION = 0
    # Local directory the data file is downloaded into and read back from.
    directory = 'data/arithmetic/'

    def __init__(self):
        super().__init__()

    def download(self):
        """Download this task's data file into ``directory`` and load it."""
        file_name, checksum = self.get_file_download_info()
        url = 'https://raw.githubusercontent.com/openai/gpt-3/master/data/' + file_name
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test when several processes start together.
        os.makedirs(self.directory, exist_ok=True)
        download_file(url, local_file=self.directory + file_name, expected_checksum=checksum)
        self.set_docs()

    @abc.abstractmethod
    def get_file_download_info(self):
        """returns a tuple of (file_name, checksum)"""
        pass

    def set_docs(self):
        """Parse the downloaded JSON-lines file into ``self._docs``.

        Each non-blank line is decoded as JSON and converted with
        ``load_doc``.
        """
        file_name, _ = self.get_file_download_info()
        # The context manager guarantees the handle is closed even if a
        # line fails to parse.
        with open(self.directory + file_name, 'r') as jsons:
            self._docs = [
                self.load_doc(json.loads(line))
                for line in jsons
                if line.strip()  # tolerate blank / trailing empty lines
            ]

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # No training split exists (see has_training_docs).
        return NotImplemented

    def validation_docs(self):
        """Return the documents loaded by ``set_docs``."""
        return self._docs

    def test_docs(self):
        # No test split exists (see has_test_docs).
        return NotImplemented

    def doc_to_text(self, doc):
        """Return the prompt text of ``doc``."""
        return doc.context

    def doc_to_target(self, doc):
        """Return the expected completion text of ``doc``."""
        return doc.completion

    def load_doc(self, doc_json):
        """Build an ``ArithmeticDoc`` from one decoded JSON record.

        The context is stripped, double newlines are collapsed to single
        ones, and the ``Q:`` / ``A:`` markers are expanded to
        ``Question:`` / ``Answer:``. The completion is kept verbatim.
        """
        context = (
            doc_json['context']
            .strip()
            .replace('\n\n', '\n')
            .replace('Q:', 'Question:')
            .replace('A:', 'Answer:')
        )
        return ArithmeticDoc(context=context, completion=doc_json['completion'])

    def construct_requests(self, doc, ctx):
        """Request the log-likelihood of the completion given ``ctx``.

        Only the second element of the ``rf.loglikelihood`` result is
        returned; the first is not used by this task.
        """
        # NOTE(review): the second element presumably flags whether the
        # completion is the model's own prediction -- confirm against
        # lm_eval.base.rf.
        _, is_prediction = rf.loglikelihood(ctx, doc.completion)
        return is_prediction

    def process_results(self, doc, results):
        """Map the single request result to the ``acc`` metric."""
        is_prediction, = results
        return {
            "acc": is_prediction
        }

    def aggregation(self):
        """Return the aggregation function for each metric."""
        return {
            "acc": mean,
        }

    def higher_is_better(self):
        """Return, per metric, whether larger values are better."""
        return {
            "acc": True
        }


class Arithmetic2DPlus(Arithmetic):
    """Arithmetic task backed by ``two_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'two_digit_addition.jsonl'
        checksum = '75a54b7a3db3b23369df74fe440c23025f3d3c51f664300bd3d56632b2617b3d'
        return file_name, checksum
112
113

class Arithmetic2DMinus(Arithmetic):
    """Arithmetic task backed by ``two_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'two_digit_subtraction.jsonl'
        checksum = 'da956066ff108c00b341d360567472784f5fd872d6465071b44a14291205bc03'
        return file_name, checksum
116
117

class Arithmetic3DPlus(Arithmetic):
    """Arithmetic task backed by ``three_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'three_digit_addition.jsonl'
        checksum = '124865e30efd2abfbc1855dd34c218fc02d32d780ace970ab9b4ea3fa74c798b'
        return file_name, checksum
120
121

class Arithmetic3DMinus(Arithmetic):
    """Arithmetic task backed by ``three_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'three_digit_subtraction.jsonl'
        checksum = '7fc6aaedcb0e2bd17c398dd4147c5585b1e608278a8e98b914e69656707d6a29'
        return file_name, checksum
124
125

class Arithmetic4DPlus(Arithmetic):
    """Arithmetic task backed by ``four_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'four_digit_addition.jsonl'
        checksum = '459c6f75baa2e8d7cf50bdd07db6d0ca9133a6b137d95d09267db85b6e07f391'
        return file_name, checksum
128
129

class Arithmetic4DMinus(Arithmetic):
    """Arithmetic task backed by ``four_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'four_digit_subtraction.jsonl'
        checksum = '0c47db40a10c052ef0cf732a9ef2edaa53d66377d43eb47a9c382d33a8af7102'
        return file_name, checksum
132
133

class Arithmetic5DPlus(Arithmetic):
    """Arithmetic task backed by ``five_digit_addition.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'five_digit_addition.jsonl'
        checksum = '30ada42efe315b958c6e9649274005d3b720e50298e92c3a2d321f8996e58f54'
        return file_name, checksum
136
137

class Arithmetic5DMinus(Arithmetic):
    """Arithmetic task backed by ``five_digit_subtraction.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'five_digit_subtraction.jsonl'
        checksum = '8b98ccfc943cbf9193bcf1984954aa0b1a4527016072d972a2b055cc1482ca3c'
        return file_name, checksum
140
141

class Arithmetic2DMultiplication(Arithmetic):
    """Arithmetic task backed by ``two_digit_multiplication.jsonl``."""

    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        file_name = 'two_digit_multiplication.jsonl'
        checksum = '5613d1d1cc3b2c03edc1990252247d34c10ec82944b2cdeb19e71b00f237f431'
        return file_name, checksum
144
145

class Arithmetic1DComposite(Arithmetic):
    """Arithmetic task backed by ``single_digit_three_ops.jsonl``."""

    # The method was indented three spaces in the original; it now uses the
    # four-space indentation of every sibling class.
    def get_file_download_info(self):
        """Return the (file_name, checksum) pair for this task's data file."""
        return 'single_digit_three_ops.jsonl', '08b34e3272a8ff1d4932d63f251519d14c485c38d582366e1e323d0b859c3925'