import dataclasses
import json
import sys
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace.
    """

    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        # Materialize the iterable: it is walked here and again on every parse call, so a one-shot iterator
        # (e.g. a generator) would otherwise be exhausted after initialization.
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _unwrap_optional(field_type):
        """
        Return `X` when `field_type` is `Optional[X]` (i.e. `Union[X, None]`), otherwise return it unchanged.

        `Optional[bool]` is deliberately left wrapped because the caller gives it dedicated flag handling.
        """
        if getattr(field_type, "__origin__", None) is not Union:
            return field_type
        members = field_type.__args__
        non_none = [member for member in members if member is not type(None)]
        if len(members) == 2 and len(non_none) == 1 and non_none[0] is not bool:
            return non_none[0]
        return field_type

    @staticmethod
    def _init_field_names(dtype):
        """Return the names of the fields of `dtype` that are accepted by its generated `__init__`."""
        return {f.name for f in dataclasses.fields(dtype) if f.init}

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """
        Register one `--<field name>` argument on this parser for every init field of the dataclass `dtype`.

        The field's `metadata` mapping is forwarded as keyword arguments to `add_argument` (e.g. `help`), and the
        field's type hint decides the remaining keyword arguments (`type`, `choices`, `nargs`, `action`, ...).

        Raises:
            ImportError: if a field annotation is a string (postponed evaluation of annotations).
        """
        for field in dataclasses.fields(dtype):
            if not field.init:
                # Such a field cannot be passed to the dataclass constructor, so it must not become an argument.
                continue
            field_name = f"--{field.name}"
            kwargs = field.metadata.copy()
            # field.metadata is not used at all by Data Classes,
            # it is provided as a third-party extension mechanism.
            if isinstance(field.type, str):
                raise ImportError(
                    "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563), "
                    "which can be opted in from Python 3.7 with `from __future__ import annotations`. "
                    "We will add compatibility when Python 3.9 is released."
                )
            # Work on a local copy of the type: the dataclass' own field metadata is left untouched. The Optional
            # wrapper is detected structurally, because its string representation differs between Python versions.
            field_type = self._unwrap_optional(field.type)
            origin = getattr(field_type, "__origin__", None)

            if isinstance(field_type, type) and issubclass(field_type, Enum):
                kwargs["choices"] = list(field_type)
                kwargs["type"] = field_type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
            elif field_type is bool or field_type == Optional[bool]:
                # NOTE: an `Optional[bool]` whose default is None (or missing) gets no action here, so argparse
                # treats it as a regular option expecting one value.
                if field_type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                    kwargs["action"] = "store_false" if field.default is True else "store_true"
                if field.default is True:
                    # A flag that defaults to True is switched off with `--no_<name>`.
                    field_name = f"--no_{field.name}"
                    kwargs["dest"] = field.name
            elif isinstance(origin, type) and issubclass(origin, list):
                # `origin` may be a typing special form (e.g. Union) which is not a class, hence the guard.
                kwargs["nargs"] = "+"
                kwargs["type"] = field_type.__args__[0]
                assert all(
                    x == kwargs["type"] for x in field_type.__args__
                ), f"{field.name} cannot be a List of mixed types"
                if field.default_factory is not dataclasses.MISSING:
                    kwargs["default"] = field.default_factory()
            else:
                kwargs["type"] = field_type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
                elif field.default_factory is not dataclasses.MISSING:
                    kwargs["default"] = field.default_factory()
                else:
                    kwargs["required"] = True
            self.add_argument(field_name, **kwargs)

    def parse_args_into_dataclasses(
        self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will add its potential content in front of the command line args.
            args_filename:
                If not None, will use this file instead of the ".args" file specified in the previous argument.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)

        Raises:
            ValueError: if `return_remaining_strings` is false and some arguments were not consumed by the parser.
        """
        # `sys.argv[0]` is empty in an interactive interpreter; there is no script name to derive a file from.
        if args_filename or (look_for_args_file and sys.argv and sys.argv[0]):
            if args_filename:
                args_file = Path(args_filename)
            else:
                args_file = Path(sys.argv[0]).with_suffix(".args")

            if args_file.exists():
                fargs = args_file.read_text().split()
                # In case of duplicate arguments argparse keeps the last occurrence, so the file content goes
                # first and the explicitly passed / command line args can override it.
                args = fargs + list(args) if args is not None else fargs + sys.argv[1:]

        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = self._init_field_names(dtype)
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Consume the attributes so that only non-dataclass arguments are left in the namespace.
            for k in keys:
                delattr(namespace, k)
            outputs.append(dtype(**inputs))
        if vars(namespace):
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        if remaining_args:
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
        return (*outputs,)

    def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file: Path of the JSON file to read; it must contain a single JSON object.

        Returns:
            Tuple with one instance per dataclass type, in the order they were passed to the initializer.
        """
        data = json.loads(Path(json_file).read_text(encoding="utf-8"))
        return self.parse_dict(data)

    def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args: Mapping from field names to values. Keys that match no dataclass field are ignored.

        Returns:
            Tuple with one instance per dataclass type, in the order they were passed to the initializer.
        """
        outputs = []
        for dtype in self.dataclass_types:
            keys = self._init_field_names(dtype)
            inputs = {k: v for k, v in args.items() if k in keys}
            outputs.append(dtype(**inputs))
        return (*outputs,)