gsm8k.py 3.74 KB
Newer Older
Jonathan Tow's avatar
Jonathan Tow committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grade School Math 8k dataset."""


import json

import datasets


# BibTeX entry for the paper that introduced GSM8K; surfaced via DatasetInfo.citation.
_CITATION = """\
@misc{cobbe2021training,
      title={Training Verifiers to Solve Math Word Problems},
      author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
      year={2021},
      eprint={2110.14168},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
"""

# Human-readable summary of the dataset; surfaced via DatasetInfo.description.
_DESCRIPTION = """\
State-of-the-art language models can match human performance on many tasks, but
they still struggle to robustly perform multi-step mathematical reasoning. To
diagnose the failures of current models and support research, we introduce GSM8K,
a dataset of 8.5K high quality linguistically diverse grade school math word problems.
We find that even the largest transformer models fail to achieve high test performance,
despite the conceptual simplicity of this problem distribution.
"""

# Upstream project page; surfaced via DatasetInfo.homepage.
_HOMEPAGE = "https://github.com/openai/grade-school-math"

# TODO: Add the licence for the dataset here if you can find it.
# NOTE(review): the upstream repository at _HOMEPAGE should state the license;
# confirm it there before filling this in. Left empty until verified.
_LICENSE = ""

# Raw JSON Lines files for each split, fetched directly from the upstream repo.
# Keys must match the split names used when building the split generators.
_URLS = {
    "train": "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl",
    "test": "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl",
}


class GSM8K(datasets.GeneratorBasedBuilder):
    """Grade School Math 8k (GSM8K) dataset builder.

    Exposes a single ``gsm8k`` configuration with ``train`` and ``test``
    splits. Every example has two string fields, ``question`` and ``answer``,
    read from the JSON Lines files listed in ``_URLS``.
    """

    VERSION = datasets.Version("0.0.1")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(
            name="gsm8k",
            version=VERSION,
            description="The Grade School Math 8k dataset.",
        ),
    ]

    def _info(self):
        """Return the dataset metadata.

        Returns:
            A ``datasets.DatasetInfo`` with a two-column string schema
            (``question``, ``answer``) plus the module-level description,
            homepage, license and citation.
        """
        features = datasets.Features(
            {
                "question": datasets.Value("string"),
                "answer": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download the split files and describe how to build each split.

        Args:
            dl_manager: download manager supplied by ``datasets``. Its
                ``download_and_extract`` is given a ``{split: url}`` dict and
                the result is indexed by the same split keys.

        Returns:
            One ``SplitGenerator`` for TRAIN and one for TEST. Each passes the
            local file path and split name to ``_generate_examples``.
        """
        # Copy only the splits we build, so the module-level mapping is never
        # handed to external code.
        urls = {"train": _URLS["train"], "test": _URLS["test"]}
        data_dir = dl_manager.download_and_extract(urls)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": data_dir["train"],
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": data_dir["test"],
                    "split": "test",
                },
            ),
        ]

    # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
    def _generate_examples(self, filepath, split):
        """Yield ``(key, example)`` pairs from one GSM8K JSON Lines file.

        Args:
            filepath: local path of the split's ``.jsonl`` file, one JSON
                object per line.
            split: split name ("train" or "test"). It is not used here, but
                it must be accepted because ``_split_generators`` passes it in
                ``gen_kwargs``.

        Yields:
            ``(key, {"question": ..., "answer": ...})`` where ``key`` is the
            0-based line index within the file.
        """
        with open(filepath, encoding="utf-8") as f:
            for key, row in enumerate(f):
                # Skip blank lines (e.g. a stray trailing newline): json.loads
                # would raise on them. Keys stay unique because they come from
                # the line index.
                if not row.strip():
                    continue
                data = json.loads(row)
                yield key, {
                    "question": data["question"],
                    "answer": data["answer"],
                }