"tests/models/beit/test_image_processing_beit.py" did not exist on "2e3452af0f53fb1b023c70a4596de48efcbfb819"
hf_argparser.py 9.97 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import dataclasses
16
17
import json
import sys
18
from argparse import ArgumentParser, ArgumentTypeError
19
from enum import Enum
20
from pathlib import Path
Sylvain Gugger's avatar
Sylvain Gugger committed
21
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
22
23
24
25
26
27


# Type-hint-only aliases: `DataClass` annotates dataclass *instances* (the parse_* return values below) and
# `DataClassType` annotates dataclass *types* (what the parser is constructed from). Both wrap `Any`, and
# `NewType` is an identity function at runtime, so no runtime type checking happens through them.
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


28
29
30
31
32
33
34
35
36
37
38
39
40
41
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
    """
    Convert a textual flag value into a `bool`, for use as an argparse `type=` callable.

    Values that are already `bool` are returned unchanged. String matching is case insensitive.

    Raises:
        ArgumentTypeError: if `v` is not one of yes/no, true/false, t/f, y/n, 1/0.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    truthy_spellings = ("yes", "true", "t", "y", "1")
    falsy_spellings = ("no", "false", "f", "n", "0")
    if normalized in truthy_spellings:
        return True
    if normalized in falsy_spellings:
        return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )


42
43
class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace.
    """

    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        # Materialize the iterable: it is walked here and again by every parse_* method, so a one-shot
        # iterator (e.g. a generator) would otherwise be exhausted right after construction.
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _set_default_or_required(field: dataclasses.Field, kwargs: dict):
        """Copy `field`'s default (or its `default_factory()` result) into `kwargs`, else mark the argument required."""
        if field.default is not dataclasses.MISSING:
            kwargs["default"] = field.default
        elif field.default_factory is not dataclasses.MISSING:
            kwargs["default"] = field.default_factory()
        else:
            kwargs["required"] = True

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """
        Register one `--<field name>` argument for every init field of `dtype`.

        `type`, `choices`, `nargs`, `const`, `default` and `required` are derived from the field's type hint and
        default; any `metadata` on the field is passed through to `add_argument` as extra keyword arguments. Bool
        fields whose default is `True` additionally get a `--no_<field name>` switch that stores `False`.

        Raises:
            ImportError: if the type hints are postponed-evaluation strings (PEP 563).
        """
        for field in dataclasses.fields(dtype):
            if not field.init:
                continue
            field_name = f"--{field.name}"
            # field.metadata is not used at all by Data Classes,
            # it is provided as a third-party extension mechanism.
            kwargs = field.metadata.copy()
            if isinstance(field.type, str):
                raise ImportError(
                    "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563), "
                    "which can be opted in from Python 3.7 with `from __future__ import annotations`. "
                    "We will add compatibility when Python 3.9 is released."
                )
            # Work on a local copy of the type so the user's dataclass Field objects are never mutated.
            field_type = field.type
            typestring = str(field_type)
            # Unwrap Optional[prim] / Optional[List[prim]] to the inner type so argparse gets a usable `type`.
            for prim_type in (int, float, str):
                for collection in (List,):
                    if (
                        typestring == f"typing.Union[{collection[prim_type]}, NoneType]"
                        or typestring == f"typing.Optional[{collection[prim_type]}]"
                    ):
                        field_type = collection[prim_type]
                if (
                    typestring == f"typing.Union[{prim_type.__name__}, NoneType]"
                    or typestring == f"typing.Optional[{prim_type.__name__}]"
                ):
                    field_type = prim_type

            if isinstance(field_type, type) and issubclass(field_type, Enum):
                kwargs["choices"] = [x.value for x in field_type]
                kwargs["type"] = type(kwargs["choices"][0])
                self._set_default_or_required(field, kwargs)
            elif field_type is bool or field_type is Optional[bool]:
                if field.default is True:
                    self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)

                # Hack because type=bool in argparse does not behave as we want.
                kwargs["type"] = string_to_bool
                if field_type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                    # Default value is True if we have no default when of type bool.
                    default = True if field.default is dataclasses.MISSING else field.default
                    # This is the value that will get picked if we don't include --field_name in any way
                    kwargs["default"] = default
                    # This tells argparse we accept 0 or 1 value after --field_name
                    kwargs["nargs"] = "?"
                    # This is the value that will get picked if we do --field_name (without value)
                    kwargs["const"] = True
            elif hasattr(field_type, "__origin__") and issubclass(field_type.__origin__, List):
                kwargs["nargs"] = "+"
                kwargs["type"] = field_type.__args__[0]
                assert all(
                    x == kwargs["type"] for x in field_type.__args__
                ), f"{field.name} cannot be a List of mixed types"
                self._set_default_or_required(field, kwargs)
            else:
                kwargs["type"] = field_type
                self._set_default_or_required(field, kwargs)
            self.add_argument(field_name, **kwargs)

    def parse_args_into_dataclasses(
        self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will prepend its potential content to the command line args (so arguments given
                explicitly take precedence over the file).
            args_filename:
                If not None, will use this file instead of the ".args" file specified in the previous argument.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)

        Raises:
            ValueError: if `return_remaining_strings` is false and some arguments were not recognized.
        """
        # sys.argv[0] is "" in an interactive interpreter, and Path("").with_suffix() raises ValueError.
        if args_filename or (look_for_args_file and len(sys.argv) and sys.argv[0]):
            if args_filename:
                args_file = Path(args_filename)
            else:
                args_file = Path(sys.argv[0]).with_suffix(".args")

            if args_file.is_file():
                fargs = args_file.read_text().split()
                # In case of duplicate arguments argparse keeps the last occurrence, so the file arguments go
                # first and explicitly passed / command-line arguments override them.
                args = fargs + list(args) if args is not None else fargs + sys.argv[1:]

        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Remove consumed values so only non-dataclass arguments are left on the namespace.
            for k in keys:
                delattr(namespace, k)
            outputs.append(dtype(**inputs))
        if vars(namespace):
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        if remaining_args:
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
        return (*outputs,)

    def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file: Path of a JSON file whose top-level object maps field names to values.

        Returns:
            Tuple with one dataclass instance per entry of `self.dataclass_types`, in the same order. Keys that do
            not match an init field of any dataclass are ignored.
        """
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        data = json.loads(Path(json_file).read_text(encoding="utf-8"))
        return self.parse_dict(data)

    def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args: Mapping of field names to values.

        Returns:
            Tuple with one dataclass instance per entry of `self.dataclass_types`, in the same order. Keys that do
            not match an init field of any dataclass are ignored.
        """
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in args.items() if k in keys}
            outputs.append(dtype(**inputs))
        return (*outputs,)