test_hf_argparser.py 9.37 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
16
17
18
19
import argparse
import unittest
from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
20
from typing import List, Optional
21

Julien Chaumond's avatar
Julien Chaumond committed
22
from transformers import HfArgumentParser, TrainingArguments
23
from transformers.hf_argparser import string_to_bool
24
25


26
27
28
29
def list_field(default=None, metadata=None):
    """
    Build a dataclass ``field`` whose default value is produced by a factory.

    Dataclasses refuse mutable defaults such as lists, so list-valued fields must use ``default_factory``. The
    factory returns a fresh deep copy of ``default`` on every call. This keeps separate dataclass instances from
    sharing, and accidentally mutating, one and the same list object.

    Args:
        default: The default value (typically a list) that every new instance starts from.
        metadata: Optional mapping forwarded to ``dataclasses.field`` (e.g. ``{"help": "..."}``).

    Returns:
        A ``dataclasses.Field`` to be used as a class-level default inside a ``@dataclass``.
    """
    import copy

    # Copy per call: returning `default` itself would alias one object across all instances.
    return field(default_factory=lambda: copy.deepcopy(default), metadata=metadata)


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@dataclass
class BasicExample:
    """Example with one field per basic scalar type (int, float, str, bool) and no defaults."""

    # None of these fields declares a default value.
    foo: int
    bar: float
    baz: str
    flag: bool  # a bool without a default; the parser treats it as an optional flag


@dataclass
class WithDefaultExample:
    """Example whose fields carry default values, one of them with ``help`` metadata."""

    foo: int = 42  # plain default
    # default plus metadata, so the parser can attach a help string to the option
    baz: str = field(default="toto", metadata={"help": "help message"})


@dataclass
class WithDefaultBoolExample:
    """Example of boolean fields defaulting to False, to True, and an ``Optional[bool]`` defaulting to None."""

    foo: bool = False
    baz: bool = True  # a True default is the case where a negating "--no_baz" switch is relevant
    opt: Optional[bool] = None
49
50
51
52
53
54
55
56
57


class BasicEnum(Enum):
    """Two-member string enum used to exercise enum-typed fields; each value equals its member name."""

    titi = "titi"
    toto = "toto"


@dataclass
class EnumExample:
    """
    Example with a single enum-typed field.

    ``foo`` defaults to the raw string ``"toto"``. ``__post_init__`` converts whatever was given (a string value or
    an existing ``BasicEnum`` member) into a ``BasicEnum`` member.
    """

    foo: BasicEnum = "toto"

    def __post_init__(self):
        # BasicEnum(x) maps a value string to its member and returns a member unchanged.
        self.foo = BasicEnum(self.foo)
62
63
64
65
66
67
68


@dataclass
class OptionalExample:
    """
    Example where every field is annotated ``Optional[...]``.

    The scalar fields default to None. The two list fields default to empty lists through ``list_field``, because
    dataclasses do not accept a plain list as a default.
    """

    foo: Optional[int] = None
    bar: Optional[float] = field(default=None, metadata={"help": "help message"})
    baz: Optional[str] = None
    ces: Optional[List[str]] = list_field(default=[])
    des: Optional[List[int]] = list_field(default=[])


@dataclass
class ListExample:
    """Example of list-typed fields (int, str, float elements) with list defaults, built via ``list_field``."""

    # An empty default and a non-empty default for integer lists.
    foo_int: List[int] = list_field(default=[])
    bar_int: List[int] = list_field(default=[1, 2, 3])
    # Non-empty defaults for the other element types.
    foo_str: List[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
    foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
79
80


81
82
83
84
85
86
87
88
89
90
@dataclass
class RequiredExample:
    """
    Example whose fields are declared with a bare ``field()`` and therefore have no default.

    ``required_enum`` is coerced into a ``BasicEnum`` member once the instance is initialised.
    """

    required_list: List[int] = field()
    required_str: str = field()
    required_enum: BasicEnum = field()

    def __post_init__(self):
        # Accept either the enum value string or a member; store a member either way.
        self.required_enum = BasicEnum(self.required_enum)


91
92
93
94
95
96
97
98
99
@dataclass
class StringLiteralAnnotationExample:
    """
    Example mixing required, optional, enum and list fields whose annotations are string literals.

    Because the annotations are strings (forward references) rather than type objects, this exercises how the
    argument parser handles un-evaluated annotations.
    """

    foo: int  # the one field annotated with a real type object
    required_enum: "BasicEnum" = field()
    opt: "Optional[bool]" = None
    baz: "str" = field(default="toto", metadata={"help": "help message"})
    foo_str: "List[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])


100
class HfArgumentParserTest(unittest.TestCase):
    """
    Tests for ``HfArgumentParser``.

    Most tests build a parser from one of the example dataclasses defined in this module. They compare it with a
    hand-built ``argparse.ArgumentParser`` and then check what parsing concrete command lines produces.
    """

    def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser):
        """
        Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.

        The parsers must register the same number of actions. Each pair of actions, in registration order, must have
        identical attributes, except ``container``: that attribute refers back to the owning parser and therefore
        always differs.
        """
        self.assertEqual(len(a._actions), len(b._actions))
        for x, y in zip(a._actions, b._actions):
            xx = {k: v for k, v in vars(x).items() if k != "container"}
            yy = {k: v for k, v in vars(y).items() if k != "container"}
            self.assertEqual(xx, yy)

    def test_basic(self):
        """Fields without defaults become required options; the un-defaulted bool becomes an optional flag."""
        parser = HfArgumentParser(BasicExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", type=int, required=True)
        expected.add_argument("--bar", type=float, required=True)
        expected.add_argument("--baz", type=str, required=True)
        expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?")
        self.argparsersEqual(parser, expected)

        # Omitting --flag must leave it at its False default.
        args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"]
        (example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
        self.assertFalse(example.flag)

    def test_with_default(self):
        """Field defaults and ``help`` metadata are forwarded to ``add_argument``."""
        parser = HfArgumentParser(WithDefaultExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", default=42, type=int)
        expected.add_argument("--baz", default="toto", type=str, help="help message")
        self.argparsersEqual(parser, expected)

    def test_with_default_bool(self):
        """Bool fields accept bare flags and explicit True/False strings; a True default also gets ``--no_*``."""
        parser = HfArgumentParser(WithDefaultBoolExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
        expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
        # A boolean no_* argument always has to come after its "default: True" regular counterpart
        # and its default must be set to False
        expected.add_argument("--no_baz", action="store_false", default=False, dest="baz")
        expected.add_argument("--opt", type=string_to_bool, default=None)
        self.argparsersEqual(parser, expected)

        args = parser.parse_args([])
        self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))

        args = parser.parse_args(["--foo", "--no_baz"])
        self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))

        args = parser.parse_args(["--foo", "--baz"])
        self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))

        args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
        self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))

        args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
        self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))

    def test_with_enum(self):
        """Enum fields are exposed as string choices and end up as ``BasicEnum`` members on the dataclass."""
        parser = HfArgumentParser(EnumExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
        self.argparsersEqual(parser, expected)

        # look_for_args_file=False keeps the test independent of any ".args" file next to sys.argv[0],
        # matching test_basic.
        args = parser.parse_args([])
        self.assertEqual(args.foo, "toto")
        enum_ex = parser.parse_args_into_dataclasses([], look_for_args_file=False)[0]
        self.assertEqual(enum_ex.foo, BasicEnum.toto)

        args = parser.parse_args(["--foo", "titi"])
        self.assertEqual(args.foo, "titi")
        enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"], look_for_args_file=False)[0]
        self.assertEqual(enum_ex.foo, BasicEnum.titi)

    def test_with_list(self):
        """List fields take one or more values (``nargs="+"``), each parsed with the element type."""
        parser = HfArgumentParser(ListExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo_int", nargs="+", default=[], type=int)
        expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
        expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
        expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)

        self.argparsersEqual(parser, expected)

        args = parser.parse_args([])
        self.assertEqual(
            args,
            Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
        )

        args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
        self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))

    def test_with_optional(self):
        """``Optional[X]`` fields parse with the type of ``X``; scalar ones default to None, list ones to []."""
        parser = HfArgumentParser(OptionalExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", default=None, type=int)
        expected.add_argument("--bar", default=None, type=float, help="help message")
        expected.add_argument("--baz", default=None, type=str)
        expected.add_argument("--ces", nargs="+", default=[], type=str)
        expected.add_argument("--des", nargs="+", default=[], type=int)
        self.argparsersEqual(parser, expected)

        args = parser.parse_args([])
        self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))

        args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
        self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))

    def test_with_required(self):
        """Fields declared with a bare ``field()`` (no default) become required options."""
        parser = HfArgumentParser(RequiredExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--required_list", nargs="+", type=int, required=True)
        expected.add_argument("--required_str", type=str, required=True)
        expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
        self.argparsersEqual(parser, expected)

    def test_with_string_literal_annotation(self):
        """String-literal annotations must produce the same options as the equivalent real type annotations."""
        parser = HfArgumentParser(StringLiteralAnnotationExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", type=int, required=True)
        expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
        expected.add_argument("--opt", type=string_to_bool, default=None)
        expected.add_argument("--baz", default="toto", type=str, help="help message")
        expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
        self.argparsersEqual(parser, expected)

    def test_parse_dict(self):
        """``parse_dict`` yields the same dataclass instance as constructing it directly from the dict."""
        parser = HfArgumentParser(BasicExample)

        args_dict = {
            "foo": 12,
            "bar": 3.14,
            "baz": "42",
            "flag": True,
        }

        parsed_args = parser.parse_dict(args_dict)[0]
        args = BasicExample(**args_dict)
        self.assertEqual(parsed_args, args)

    def test_integration_training_args(self):
        """Smoke test: a parser can be built from the real ``TrainingArguments`` dataclass."""
        parser = HfArgumentParser(TrainingArguments)
        self.assertIsNotNone(parser)