instantiate.py 6.2 KB
Newer Older
yuguo960516's avatar
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import logging
from collections import abc
from enum import Enum
from typing import Any, Callable, Dict, List, Union

from hydra.errors import InstantiationException
from omegaconf import OmegaConf

from libai.config.lazy import _convert_target_to_string, locate

logger = logging.getLogger(__name__)

__all__ = ["dump_dataclass", "instantiate"]

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/config/instantiate.py
# --------------------------------------------------------


class _Keys(str, Enum):
    """Special keys in configs used by instantiate.

    The ``str`` mix-in lets each member be compared with, and used as a
    lookup key in place of, its plain string value.
    """

    # Key whose value is the callable (or its dotted-path string) to invoke.
    TARGET = "_target_"
    # Key holding the bool flag that decides whether nested values are
    # instantiated before being passed to the target.
    RECURSIVE = "_recursive_"


def _is_target(x: Any) -> bool:
    """Return True when ``x`` is a dict or OmegaConf dict config holding a "_target_" key."""
    dict_like = isinstance(x, dict) or OmegaConf.is_dict(x)
    return dict_like and _Keys.TARGET in x


def _is_dict(cfg: Any) -> bool:
    """Return True when ``cfg`` is an OmegaConf dict config or any ``Mapping``."""
    if OmegaConf.is_dict(cfg):
        return True
    return isinstance(cfg, abc.Mapping)


def _is_list(cfg: Any) -> bool:
    """Return True when ``cfg`` is an OmegaConf list config or a Python ``list``."""
    if OmegaConf.is_list(cfg):
        return True
    return isinstance(cfg, list)


def dump_dataclass(obj: Any):
    """
    Recursively convert a dataclass instance into a plain dict that can be
    instantiated again later.

    The result stores the dataclass type under "_target_" and one entry per
    field. Field values that are dataclasses are dumped the same way; for
    list/tuple values, each dataclass element is dumped and the sequence is
    stored as a list.

    Args:
        obj: a dataclass instance (not a dataclass type)

    Returns:
        dict
    """
    assert dataclasses.is_dataclass(obj) and not isinstance(
        obj, type
    ), "dump_dataclass() requires an instance of a dataclass."

    dumped = {"_target_": _convert_target_to_string(type(obj))}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        if dataclasses.is_dataclass(value):
            dumped[field.name] = dump_dataclass(value)
        elif isinstance(value, (list, tuple)):
            dumped[field.name] = [
                dump_dataclass(item) if dataclasses.is_dataclass(item) else item
                for item in value
            ]
        else:
            dumped[field.name] = value
    return dumped


def _prepare_input_dict_or_list(d: Union[Dict[Any, Any], List[Any]]) -> Any:
    res: Any
    if isinstance(d, dict):
        res = {}
        for k, v in d.items():
            if k == "_target_":
                v = _convert_target_to_string(d["_target_"])
            elif isinstance(v, (dict, list)):
                v = _prepare_input_dict_or_list(v)
            res[k] = v
    elif isinstance(d, list):
        res = []
        for v in d:
            if isinstance(v, (list, dict)):
                v = _prepare_input_dict_or_list(v)
            res.append(v)
    else:
        assert False
    return res


def _resolve_target(target):
    if isinstance(target, str):
        try:
            target = locate(target)
        except Exception as e:
            msg = f"Error locating target '{target}', see chained exception above."
            raise InstantiationException(msg) from e

    if not callable(target):
        msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'"
        raise InstantiationException(msg)
    return target


def _call_target(_target_: Callable[..., Any], kwargs: Dict[str, Any]):
    """Call target (type) with kwargs"""

    try:
        return _target_(**kwargs)
    except Exception as e:
        msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}"
        raise InstantiationException(msg) from e


def instantiate(cfg, **kwargs: Any) -> Any:
    """
    Recursively build the objects that a config describes through
    "_target_" and its arguments.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments. Lists are processed item
            by item; ``None`` and any other value are returned unchanged.
        **kwargs: entries merged into a dict-like ``cfg`` before it is
            instantiated. "_recursive_" (default True) is passed on as the
            recursive flag.

    Returns:
        object instantiated by cfg
    """
    if cfg is None:
        return None

    if isinstance(cfg, (dict, list)):
        cfg = _prepare_input_dict_or_list(cfg)
    overrides = _prepare_input_dict_or_list(kwargs)

    cfg_is_dict = _is_dict(cfg)
    if not cfg_is_dict and not _is_list(cfg):
        return cfg  # return as-is if don't know what to do

    # Overrides only apply to dict-like configs; the merge happens before the
    # flag is popped, so "_recursive_" is also written into the merged config.
    if cfg_is_dict and overrides:
        cfg = OmegaConf.merge(cfg, overrides)

    return instantiate_cfg(cfg, recursive=overrides.pop(_Keys.RECURSIVE, True))


def instantiate_cfg(cfg: Any, recursive: bool = True):
    """
    Instantiate a single config node.

    Args:
        cfg: the node to process. ``None``, a dict-like node without
            "_target_", and any value that is neither a list nor dict-like
            are returned unchanged.
        recursive: whether the argument values of a "_target_" node are
            instantiated before being passed to the target. A "_recursive_"
            entry in a dict-like ``cfg`` takes precedence over this argument.

    Returns:
        For a "_target_" node, the result of calling the resolved target with
        the remaining keys as keyword arguments. For an OmegaConf list, a new
        OmegaConf list of processed items. For a Python list, a new list of
        processed items. Otherwise ``cfg`` itself.

    Raises:
        TypeError: if the effective recursive flag is not a bool.
        InstantiationException: if the target cannot be resolved or the call
            to it fails.
    """
    if cfg is None:
        return cfg

    # A node's own "_recursive_" entry overrides the flag inherited from the caller.
    if _is_dict(cfg):
        recursive = cfg[_Keys.RECURSIVE] if _Keys.RECURSIVE in cfg else recursive

    if not isinstance(recursive, bool):
        msg = f"Instantiation: _recursive_ flag must be a bool, got {type(recursive)}"
        raise TypeError(msg)

    # OmegaConf list: build a new OmegaConf list from the processed items.
    if OmegaConf.is_list(cfg):
        items = [instantiate_cfg(item, recursive=recursive) for item in cfg._iter_ex(resolve=True)]
        # allow_objects lets the new list hold the instantiated Python objects.
        return OmegaConf.create(items, flags={"allow_objects": True})

    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper.
        # Recurse through instantiate_cfg rather than instantiate():
        # instantiate() treats every keyword argument as a config override,
        # so passing the flag as ``recursive=...`` would merge a spurious
        # "recursive" entry into dict items (and on to their targets) while
        # dropping the actual flag.
        return [instantiate_cfg(item, recursive=recursive) for item in cfg]

    if _is_dict(cfg):
        if not _is_target(cfg):
            return cfg

        _target_ = instantiate(cfg.get(_Keys.TARGET))  # instantiate lazy target
        _target_ = _resolve_target(_target_)

        # The special keys steer instantiation and are never passed to the target.
        exclude_keys = {"_target_", "_recursive_"}
        kwargs = {}
        for key, value in cfg.items():
            if key in exclude_keys:
                continue
            if recursive:
                value = instantiate_cfg(value, recursive=recursive)
            kwargs[key] = value
        return _call_target(_target_, kwargs)

    return cfg  # return as-is if don't know what to do