args.py 5.04 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import inspect
import logging
from typing import Any, Callable, Dict, Optional, Union

from .core import GET_ARGPARSER_FN_NAME, load_from_file

LOGGER = logging.getLogger(__name__)


def str2bool(v):
    """Convert a command-line style truth value into a ``bool``.

    Suitable as an ``argparse`` ``type=`` callable.

    Args:
        v: Value to interpret. ``bool`` instances are returned unchanged; any
            other value is matched case-insensitively by its string form.

    Returns:
        ``True`` for "yes", "true", "t", "y", "1" and ``False`` for
        "no", "false", "f", "n", "0".

    Raises:
        argparse.ArgumentTypeError: If the value matches neither set.
    """
    if isinstance(v, bool):
        return v
    # Go through str() so non-string inputs (e.g. the ints 0/1, or None) are
    # either accepted or rejected with ArgumentTypeError instead of failing
    # with AttributeError on a missing ``.lower``.
    normalized = str(v).lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


def filter_fn_args(args: Union[dict, argparse.Namespace], fn: Callable) -> dict:
    """Keep only the entries of ``args`` whose names are parameters of ``fn``.

    Args:
        args: Mapping of argument names to values, or an ``argparse.Namespace``
            (converted to a mapping with ``vars``).
        fn: Callable whose signature decides which names are kept.

    Returns:
        A new dict with the matching entries, in the order they appear in ``args``.
    """
    accepted = inspect.signature(fn).parameters
    candidates = vars(args) if isinstance(args, argparse.Namespace) else args
    selected = {}
    for key, value in candidates.items():
        if key in accepted:
            selected[key] = value
    return selected


def _resolve_argument_type(parameter, fn):
    """Return the converter argparse should use for ``parameter``.

    A union annotation with exactly one non-``None`` member (``Optional[X]``,
    ``Union[X, None]`` or ``X | None``) resolves to that member; any other
    annotation is returned unchanged.

    Raises:
        RuntimeError: If a union has zero or several non-``None`` members.
    """
    import types as _types  # local import: the module header does not import ``types``

    annotation = parameter.annotation
    # ``X | None`` produces ``types.UnionType`` objects, which only exist on newer Pythons.
    pep604_union = getattr(_types, "UnionType", None)
    is_union = getattr(annotation, "__origin__", None) is Union or (
        pep604_union is not None and isinstance(annotation, pep604_union)
    )
    if not is_union:
        return annotation

    # Identity check instead of ``isinstance(None, member)``: the latter raises
    # TypeError when a member is a subscripted generic such as ``List[int]``.
    members = [member for member in annotation.__args__ if member is not type(None)]
    if len(members) != 1:
        raise RuntimeError(
            f"Could not prepare argument parser for {parameter.name}: {parameter.annotation} in {fn}"
        )
    return members[0]


def add_args_for_fn_signature(parser, fn) -> argparse.ArgumentParser:
    """Add one ``--option`` to ``parser`` for every parameter of ``fn``.

    Parameters named ``self``, ``args`` and ``kwargs`` are skipped. Underscores
    in parameter names become dashes in the option name. The annotation, when
    present, is used as the argparse ``type``; boolean parameters are parsed
    with ``str2bool``. Parameters without a default are marked as required.

    Args:
        parser: Parser (or argument group) to extend. Its ``conflict_handler``
            is set to ``"resolve"`` so re-adding an option replaces the old one.
        fn: Callable whose signature is inspected.

    Returns:
        The same ``parser`` object that was passed in.

    Raises:
        RuntimeError: If a union annotation cannot be reduced to a single type.
    """
    parser.conflict_handler = "resolve"
    for parameter in inspect.signature(fn).parameters.values():
        if parameter.name in ("self", "args", "kwargs"):
            continue

        argument_kwargs = {}
        if parameter.annotation is not inspect.Parameter.empty:
            argument_type = _resolve_argument_type(parameter, fn)
            if argument_type is bool:
                # ``bool("false")`` is True, so booleans (including Optional[bool])
                # must be parsed from their textual form.
                argument_kwargs["type"] = str2bool
                argument_kwargs["choices"] = [0, 1]
            else:
                argument_kwargs["type"] = argument_type

        # Identity comparison avoids invoking a custom ``__ne__`` on the default value.
        if parameter.default is not inspect.Parameter.empty:
            if parameter.annotation is bool:
                argument_kwargs["default"] = str2bool(parameter.default)
            else:
                argument_kwargs["default"] = parameter.default
        else:
            argument_kwargs["required"] = True

        name = parameter.name.replace("_", "-")
        LOGGER.debug("Adding argument %s with %s", name, argument_kwargs)
        parser.add_argument(f"--{name}", **argument_kwargs)
    return parser


class ArgParserGenerator:
    """Derive argparse options from a class or function and call it with parsed values.

    The options come from the signature of the wrapped function, or of the
    wrapped object's ``__init__`` when it is not a plain function. When
    ``module_path`` points at a ``.py`` file, extra options may be contributed
    by callables loaded from that file (see ``_update_argparser``).
    """

    def __init__(self, cls_or_fn, module_path: Optional[str] = None):
        """Store the wrapped callable and, if given, the Python file to consult.

        Args:
            cls_or_fn: Function or class to build arguments for.
            module_path: Optional path; only kept when it ends with ``.py``.
        """
        self._cls_or_fn = cls_or_fn

        # Plain functions are inspected directly; anything else through ``__init__``.
        self._handle = cls_or_fn if inspect.isfunction(cls_or_fn) else getattr(cls_or_fn, "__init__")
        input_is_python_file = module_path and module_path.endswith(".py")
        self._input_path = module_path if input_is_python_file else None
        self._required_fn_name_for_signature_parsing = getattr(
            cls_or_fn, "required_fn_name_for_signature_parsing", None
        )

    def update_argparser(self, parser):
        """Add an argument group to ``parser`` holding the signature and custom options."""
        name = self._handle.__name__
        group_parser = parser.add_argument_group(name)
        add_args_for_fn_signature(group_parser, fn=self._handle)
        self._update_argparser(group_parser)

    def get_args(self, args: Union[argparse.Namespace, Dict]):
        """Select from ``args`` the values needed to call the wrapped callable.

        Args:
            args: Parsed arguments, as a namespace or as a mapping.

        Returns:
            Dict with the entries matching the inspected signature, plus the
            values of every custom option registered by ``_update_argparser``.
        """
        filtered_args = filter_fn_args(args, fn=self._handle)

        # A throwaway parser is used only to discover the custom option names.
        tmp_parser = argparse.ArgumentParser(allow_abbrev=False)
        self._update_argparser(tmp_parser)
        custom_names = [
            p.dest.replace("-", "_") for p in tmp_parser._actions if not isinstance(p, argparse._HelpAction)
        ]
        # ``from_args`` accepts mappings as well as namespaces, so custom
        # values must be read with item access when ``args`` is not a namespace.
        if isinstance(args, argparse.Namespace):
            custom_params = {n: getattr(args, n) for n in custom_names}
        else:
            custom_params = {n: args[n] for n in custom_names}
        return {**filtered_args, **custom_params}

    def from_args(self, args: Union[argparse.Namespace, Dict]):
        """Call the wrapped class or function with the values selected by ``get_args``."""
        args = self.get_args(args)
        LOGGER.info(f"Initializing {self._cls_or_fn.__name__}({args})")
        return self._cls_or_fn(**args)

    def _update_argparser(self, parser):
        """Add options contributed by the Python file at ``self._input_path``, if any.

        A callable loaded under ``GET_ARGPARSER_FN_NAME`` is invoked with
        ``parser``. Otherwise, if the wrapped object declares
        ``required_fn_name_for_signature_parsing``, the callable with that name
        is loaded and its signature is turned into options. Does nothing when
        no ``.py`` path was given.
        """
        label = "argparser_update"
        if self._input_path:
            update_argparser_handle = load_from_file(self._input_path, label=label, target=GET_ARGPARSER_FN_NAME)
            if update_argparser_handle:
                update_argparser_handle(parser)
            elif self._required_fn_name_for_signature_parsing:
                fn_handle = load_from_file(
                    self._input_path, label=label, target=self._required_fn_name_for_signature_parsing
                )
                if fn_handle:
                    add_args_for_fn_signature(parser, fn_handle)