lazy_load.py 9.14 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import ast
import dataclasses
import inspect
import logging
import pydoc
from collections import abc
from typing import Any, List

from omegaconf import DictConfig

try:
    from ast import unparse
except ImportError:
    from astunparse import unparse


class Config:
    """Load a Python config file by executing it into a namespace.

    Every top-level name the file defines (except ``__builtins__``) becomes an
    attribute of the instance. Assignments whose target is listed in
    ``partials`` are rewritten from ``x = f(*a, **kw)`` to
    ``x = partial(f, *a, **kw)`` before execution, so the callable is bound
    but not invoked.

    Args:
        file_path: path of the Python config file to execute.
        name_space: globals dict the code is executed in; it is mutated in
            place. When omitted, a fresh dict is used for every instance.
        partials: assignment-target source strings (e.g. ``"model"`` or
            ``"cfg['opt']"``) whose call should be wrapped in
            ``functools.partial``.

    Note:
        The file is run with ``exec``; only load trusted config files.
    """

    def __init__(self, file_path, name_space=None, partials=()):
        # A fresh dict per instance: a shared ``{}`` default would be filled by
        # ``exec`` and leak names from one loaded config into every later one.
        if name_space is None:
            name_space = {}
        self.partials = partials
        with open(file_path, "r", encoding="utf-8") as f:
            code = f.read()
        if len(partials) != 0:
            code = self.partial_optim(code)
        exec(code, name_space)
        # Replace the instance dict wholesale (this also drops ``partials``)
        # with the names the config file defined.
        self.__dict__ = {k: v for k, v in name_space.items() if k != "__builtins__"}

    def partial_optim(self, code):
        """Return ``code`` with selected call assignments wrapped in ``partial``.

        Args:
            code: Python source of the config file.

        Returns:
            Re-generated source, prefixed with ``from functools import partial``,
            in which every ``target = f(...)`` whose target is in
            ``self.partials`` became ``target = partial(f, ...)``.
        """
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if not (isinstance(node, ast.Assign) and isinstance(node.value, ast.Call)):
                continue
            # Only the first target of chained assignments (a = b = f()) is checked.
            assign_target = unparse(node.targets[0]).rstrip("\n")
            # unparse renders string subscripts with single quotes; also accept
            # the double-quoted spelling in ``partials``.
            variant = assign_target.replace("'", '"')
            if assign_target in self.partials or variant in self.partials:
                node.value = ast.Call(
                    func=ast.Name(id="partial", ctx=ast.Load()),
                    args=[node.value.func] + node.value.args,
                    keywords=list(node.value.keywords),
                )
        return "from functools import partial\n" + unparse(tree)


class LazyConfig:
    """Load a Python config file, turning selected calls into lazy call configs.

    Every top-level name the file defines (except ``__builtins__``) becomes an
    attribute of the instance. Assignments whose target is listed in ``lazy``
    are rewritten from ``x = f(*a, **kw)`` to ``x = L(f)(*a, **kw)``, where
    ``L`` is ``LazyCall`` imported from ``util.lazy_load``.

    Args:
        file_path: path of the Python config file to execute.
        name_space: globals dict the code is executed in; it is mutated in
            place. When omitted, a fresh dict is used for every instance.
        lazy: assignment-target source strings (e.g. ``"model"`` or
            ``"cfg['opt']"``) whose call should be wrapped in ``LazyCall``.

    Note:
        The file is run with ``exec``; only load trusted config files.
    """

    def __init__(self, file_path, name_space=None, lazy=()):
        # A fresh dict per instance: a shared ``{}`` default would be filled by
        # ``exec`` and leak names from one loaded config into every later one.
        if name_space is None:
            name_space = {}
        self.lazy = lazy
        with open(file_path, "r", encoding="utf-8") as f:
            code = f.read()
        if len(self.lazy) != 0:
            code = self.replace_call_with_lazy_call(code)
        exec(code, name_space)
        # Replace the instance dict wholesale (this also drops ``lazy``) with
        # the names the config file defined.
        self.__dict__ = {k: v for k, v in name_space.items() if k != "__builtins__"}

    def replace_call_with_lazy_call(self, code):
        """Return ``code`` with selected call assignments wrapped in ``L(...)``.

        Args:
            code: Python source of the config file.

        Returns:
            Re-generated source, prefixed with
            ``from util.lazy_load import LazyCall as L``, in which every
            ``target = f(...)`` whose target is in ``self.lazy`` became
            ``target = L(f)(...)``.
        """
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if not (isinstance(node, ast.Assign) and isinstance(node.value, ast.Call)):
                continue
            # Only the first target of chained assignments (a = b = f()) is checked.
            assign_target = unparse(node.targets[0]).rstrip("\n")
            # unparse renders string subscripts with single quotes; also accept
            # the double-quoted spelling in ``lazy``.
            variant = assign_target.replace("'", '"')
            if assign_target in self.lazy or variant in self.lazy:
                node.value = ast.Call(
                    func=ast.Call(
                        func=ast.Name(id="L", ctx=ast.Load()),
                        args=[node.value.func],
                        keywords=[],
                    ),
                    args=node.value.args,
                    keywords=node.value.keywords,
                )
        return "from util.lazy_load import LazyCall as L\n" + unparse(tree)


def is_dataclass(obj):
    """Tell whether ``obj`` is a dataclass type or an instance of one.

    Classes are inspected directly; anything else (including typing aliases
    such as ``List[int]``) is inspected through its type.
    """
    generic_alias = type(List[int])
    looks_like_class = isinstance(obj, type) and not isinstance(obj, generic_alias)
    owner = obj if looks_like_class else type(obj)
    return hasattr(owner, "__dataclass_fields__")


def locate(name: str) -> Any:
    """
    Resolve a dotted path such as "module.submodule.class_name" to the object
    it names (the inverse of ``{x.__module__}.{x.__qualname__}``).

    ``pydoc.locate`` is tried first. If it returns None, hydra's private
    ``_locate`` is used as a fallback, which raises when the path cannot be
    resolved.

    Raises:
        ImportError: if pydoc cannot find the object and hydra is not installed.
    """
    found = pydoc.locate(name)
    if found is not None:
        return found

    # pydoc.locate misses some objects (e.g. torch.optim.sgd.SGD). Fall back to
    # hydra's private resolver; the public get_method would print many errors.
    try:
        from hydra.utils import _locate
    except ImportError as e:
        raise ImportError(f"Cannot dynamically locate object {name}!") from e
    return _locate(name)  # raises on failure


def _convert_target_to_string(t: Any) -> str:
    """
    Build the dotted path of ``t`` (the inverse of ``locate()``).

    The shortest module prefix that still resolves to the very same object is
    preferred: e.g. ``module.submodule._impl.class`` may become
    ``module.submodule.class``. That keeps the string simple and less tied to
    where the implementation lives. When no shorter prefix works, the full
    ``module.qualname`` is returned.

    Args:
        t: any object with ``__module__`` and ``__qualname__``
    """
    module, qualname = t.__module__, t.__qualname__
    parts = module.split(".")

    # Shortest prefixes first; the full module path is the final fallback below.
    for depth in range(1, len(parts)):
        shortened = ".".join(parts[:depth]) + "." + qualname
        try:
            resolved = locate(shortened)
        except ImportError:
            continue
        if resolved is t:
            return shortened
    return f"{module}.{qualname}"


class LazyCall:
    """
    Wrap a callable so that calling the wrapper does not execute the call but
    returns a ``DictConfig`` describing it.

    Keyword arguments are stored as-is. The target is stored under
    ``_target_``. Positional arguments are mapped onto the target's parameter
    names (see ``transfer_args_into_kwargs``); values that fall into the
    target's ``*args`` are stored under ``_variable_args_`` (None otherwise).
    Positional arguments require a callable target whose signature can be
    inspected. String or mapping targets accept keyword arguments only.

    Examples:
    ::
        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
        layer_cfg.out_channels = 64   # can edit it afterwards
        layer = instantiate(layer_cfg)
    """

    def __init__(self, target):
        if not (callable(target) or isinstance(target, (str, abc.Mapping))):
            raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
        self._target = target

    def __call__(self, *args, **kwargs):
        """Return a DictConfig recording this call instead of performing it.

        Raises:
            TypeError: if a parameter receives both a positional and a keyword
                value, or if the positional arguments cannot be mapped onto the
                target (see ``transfer_args_into_kwargs``).
        """
        if is_dataclass(self._target):
            # omegaconf object cannot hold dataclass type
            # https://github.com/omry/omegaconf/issues/784
            target = _convert_target_to_string(self._target)
        else:
            target = self._target
        variable_args, arg_kwargs = self.transfer_args_into_kwargs(args)

        # Mirror Python call semantics instead of letting the positional value
        # silently overwrite the keyword one.
        duplicated = kwargs.keys() & arg_kwargs.keys()
        if duplicated:
            raise TypeError(f"LazyCall got multiple values for argument(s): {sorted(duplicated)}")

        kwargs.update(arg_kwargs)
        kwargs["_target_"] = target
        kwargs["_variable_args_"] = variable_args

        return DictConfig(content=kwargs, flags={"allow_objects": True})

    def transfer_args_into_kwargs(self, args):
        """Map positional ``args`` onto the target's parameter names.

        Args:
            args: tuple of positional arguments passed to ``__call__``.

        Returns:
            ``(variable_args, kwargs)``. ``kwargs`` maps each positional
            parameter name to its value. ``variable_args`` is the slice of
            ``args`` that falls into the target's ``*args`` parameter, or None
            when no value reaches it.

        Raises:
            TypeError: if positional arguments are given for a target that is
                not callable (e.g. a str or mapping target), or if there are
                more positional arguments than the target can accept.
        """
        kwargs = {}
        if not args:
            # Nothing to map: skip signature inspection so str/Mapping targets
            # (accepted by __init__) and signature-less callables still work.
            return None, kwargs
        if not callable(self._target):
            raise TypeError(
                f"Positional arguments need a callable target to inspect; got {self._target!r}"
            )

        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        params = inspect.signature(self._target).parameters
        for arg_ind, (name, param) in enumerate(params.items()):
            if arg_ind >= len(args):
                break
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                return args[arg_ind:], kwargs
            if param.kind not in positional_kinds:
                # keyword-only / **kwargs parameters cannot take positional values
                break
            kwargs[name] = args[arg_ind]

        if len(kwargs) < len(args):
            raise TypeError(
                f"{self._target!r} accepts at most {len(kwargs)} positional "
                f"argument(s) but {len(args)} were given"
            )
        return None, kwargs


def _call_with_variable_args(cls, kwargs, variable_args):
    """Call ``cls`` as ``cls(*named_positionals, *variable_args, **rest)``.

    The values of the parameters declared before ``*args`` in ``cls``'s
    signature are popped from ``kwargs`` (which is mutated) and passed
    positionally. They are followed by ``variable_args``, then by whatever
    keyword arguments remain.
    """
    params = inspect.signature(cls).parameters
    kinds = [p.kind for p in params.values()]
    try:
        stop = kinds.index(inspect.Parameter.VAR_POSITIONAL)
    except ValueError:
        stop = None  # no *args parameter: consider every parameter
    args = [kwargs.pop(key) for key in list(params.keys())[:stop]]
    args.extend(variable_args)
    return cls(*args, **kwargs)


def instantiate(cfg):
    """
    Recursively instantiate objects defined in dictionaries by
    "_target_" and arguments.

    The input is handled as follows:

    - Lists and ListConfigs are instantiated element-wise.
    - DictConfigs backed by a dataclass are converted with
      ``OmegaConf.to_object``.
    - Any mapping with a "_target_" key becomes a call to that target.
    - Anything else is returned unchanged.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments. An optional
            "_variable_args_" key (as written by ``LazyCall``) holds extra
            positional values for the target's ``*args``. When it is missing,
            the target is called with keyword arguments only.

    Returns:
        object instantiated by cfg

    Raises:
        AssertionError: if the resolved "_target_" is None or not callable.
        TypeError: re-raised after logging when calling the target raises it.
    """
    from omegaconf import DictConfig, ListConfig, OmegaConf

    if isinstance(cfg, ListConfig):
        lst = [instantiate(x) for x in cfg]
        return ListConfig(lst, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(x) for x in cfg]

    # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
    # instantiate it to the actual dataclass.
    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        cfg = {k: instantiate(v) for k, v in cfg.items()}
        cls = cfg.pop("_target_")
        # Mappings not produced by LazyCall (hand-written or hydra-style configs)
        # may lack this key; treat that as "no *args".
        variable_args = cfg.pop("_variable_args_", None)
        cls = instantiate(cls)

        if isinstance(cls, str):
            cls_name = cls
            cls = locate(cls_name)
            assert cls is not None, cls_name
        else:
            try:
                cls_name = cls.__module__ + "." + cls.__qualname__
            except Exception:
                # target could be anything, so the above could fail
                cls_name = str(cls)
        assert callable(cls), f"_target_ {cls} does not define a callable object"
        try:
            if variable_args is not None:
                return _call_with_variable_args(cls, cfg, variable_args)
            return cls(**cfg)
        except TypeError:
            import os

            logger = logging.getLogger(os.path.basename(os.getcwd()) + "." + __name__)
            logger.error(f"Error when instantiating {cls_name}!")
            raise
    return cfg  # return as-is if don't know what to do