utils.py 5.11 KB
Newer Older
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

4
import inspect
5
import warnings
6
from collections import defaultdict
7
from typing import Any, List, Dict
8
from pathlib import Path
9

QuanluZhang's avatar
QuanluZhang committed
10

11
def import_(target: str, allow_none: bool = False) -> Any:
    """Import and return the object named by a dotted path such as ``'package.module.Name'``.

    Everything before the last ``.`` is imported as a module, and the part after it is looked
    up as an attribute of that module.

    Args:
        target: Fully-qualified dotted name, or ``None``.
        allow_none: Not consulted by the current implementation: a ``None`` target yields
            ``None`` regardless of this flag. Kept so existing callers remain valid.

    Returns:
        The resolved object, or ``None`` when ``target`` is ``None``.

    Raises:
        ValueError: if ``target`` contains no ``.`` separating module path and attribute name.
        ImportError / AttributeError: if the module or the attribute cannot be found.
    """
    if target is None:
        return None
    path, sep, identifier = target.rpartition('.')
    if not sep:
        # Previously this surfaced as an opaque tuple-unpacking ValueError.
        raise ValueError(f'Expected a dotted name like "module.attribute", got {target!r}.')
    # A non-empty fromlist makes __import__ return the leaf module ``path`` (not its top-level
    # package), and lets it import ``identifier`` as a submodule when it is one.
    module = __import__(path, globals(), locals(), [identifier])
    return getattr(module, identifier)
QuanluZhang's avatar
QuanluZhang committed
17

18

19
20
21
22
23
def version_larger_equal(a: str, b: str) -> bool:
    """Return True if version string ``a`` is greater than or equal to version string ``b``.

    - A local version label (anything after ``+``, e.g. ``'1.9.0+cu111'``) is ignored.
    - Each dot-separated component is compared by its leading decimal digits, so suffixes such
      as ``'0a0'`` or ``'0rc1'`` compare as ``0``. Pre-release markers are NOT ordered below
      the corresponding release.
    - Missing trailing components count as ``0``, so ``'1.8'`` equals ``'1.8.0'``.
    """
    import re

    def _parse(version):
        # Keep only the public release part, then take the numeric prefix of each component.
        release = version.split('+')[0]
        numbers = []
        for piece in release.split('.'):
            match = re.match(r'[0-9]+', piece)
            numbers.append(int(match.group()) if match else 0)
        return numbers

    lhs, rhs = _parse(a), _parse(b)
    # Pad to equal length so that e.g. (1, 8) and (1, 8, 0) compare equal.
    width = max(len(lhs), len(rhs))
    lhs += [0] * (width - len(lhs))
    rhs += [0] * (width - len(rhs))
    return tuple(lhs) >= tuple(rhs)
24

25

26
27
# Counter registry: namespace name -> last id handed out in that namespace.
# ``defaultdict(int)`` makes a namespace that was never used start from 0.
_last_uid = defaultdict(int)

# Namespace key used by ``ModelNamespace`` when no key is supplied.
_DEFAULT_MODEL_NAMESPACE = 'model'


def uid(namespace: str = 'default') -> int:
    """Advance the counter of ``namespace`` and return its new value (1, 2, 3, ...)."""
    next_id = _last_uid[namespace] + 1
    _last_uid[namespace] = next_id
    return next_id


def reset_uid(namespace: str = 'default') -> None:
    """Restart the counter of ``namespace`` so that the next ``uid`` call returns 1."""
    _last_uid[namespace] = 0


40
41
def get_module_name(cls_or_func):
    """Return the module name under which ``cls_or_func`` should be imported.

    Normally this is ``cls_or_func.__module__``. When that is ``'__main__'`` (the object was
    defined in the launching script), the call stack is searched for a frame belonging to the
    ``__main__`` module. The stem of that frame's source file (e.g. ``'train'`` for
    ``train.py``) is then used as the module name.

    Raises:
        RuntimeError: if the main script is not located in the current working directory.

    If no ``__main__`` frame with a source file is found, a warning is emitted and
    ``'__main__'`` is returned.
    """
    module_name = cls_or_func.__module__

    if module_name == '__main__':
        # infer the module name with inspect.
        # context=0: only the frame objects are needed, so skip reading source lines per frame.
        stack = inspect.stack(0)
        try:
            for frame_info in stack:
                module = inspect.getmodule(frame_info.frame)
                # getmodule() returns None for frames it cannot map to a module (e.g. exec'd code).
                if module is None or module.__name__ != '__main__':
                    continue
                source_file = inspect.getsourcefile(frame_info.frame)
                if source_file is None:
                    # No source file (e.g. an interactive session): nothing to derive a name from.
                    continue
                # main module found
                main_file_path = Path(source_file)
                if not Path().samefile(main_file_path.parent):
                    raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
                                       f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
                module_name = main_file_path.stem
                break
        finally:
            # Frame objects reference this function's locals; drop the list to avoid a reference cycle.
            del stack

    if module_name == '__main__':
        warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
                      'function/class cannot be imported.')

    # NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
    # to make LSTM's source code can be found, we should assign original LSTM's __module__ to
    # the wrapped LSTM's __module__
    # TODO: find out all the modules that have the same requirement as LSTM
    # NOTE(review): with the logic above, module_name already equals cls_or_func.__module__ when
    # this condition holds, so this branch is currently a no-op; kept to preserve intent.
    if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
        module_name = cls_or_func.__module__

    return module_name


67
def get_importable_name(cls, relocate_module=False):
    """Return ``'<module>.<name>'`` for ``cls``.

    With ``relocate_module=True`` the module part comes from ``get_module_name`` (which maps
    ``'__main__'`` to the launching script's name); otherwise ``cls.__module__`` is used as-is.
    """
    if relocate_module:
        owner = get_module_name(cls)
    else:
        owner = cls.__module__
    return '.'.join((owner, cls.__name__))
70
71


Yuge Zhang's avatar
Yuge Zhang committed
72
73
74
75
class NoContextError(Exception):
    """Raised by ``ContextStack.top`` when nothing has been pushed for the requested key."""


76
77
78
79
80
81
class ContextStack:
    """
    Maintain a globally-accessible context environment that is visible everywhere.

    Use ``with ContextStack(namespace, value):`` to push ``value`` for the duration of the block,
    and use ``get_current_context(namespace)`` to get the innermost value in that namespace.

    Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
    """

    # Class-level, so shared by all instances: key -> list of pushed values, innermost last.
    _stack: Dict[str, List[Any]] = defaultdict(list)

    def __init__(self, key: str, value: Any):
        self.key, self.value = key, value

    def __enter__(self):
        self.push(self.key, self.value)
        return self

    def __exit__(self, *args, **kwargs):
        # Exceptions are not suppressed: this returns None.
        self.pop(self.key)

    @classmethod
    def push(cls, key: str, value: Any):
        """Make ``value`` the innermost entry for ``key``."""
        entries = cls._stack[key]
        entries.append(value)

    @classmethod
    def pop(cls, key: str) -> None:
        """Discard the innermost entry for ``key`` (IndexError if there is none)."""
        entries = cls._stack[key]
        entries.pop()

    @classmethod
    def top(cls, key: str) -> Any:
        """Return the innermost entry for ``key``; raise ``NoContextError`` if there is none."""
        entries = cls._stack[key]
        if entries:
            return entries[-1]
        raise NoContextError('Context is empty.')


114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class ModelNamespace:
    """
    Context manager that creates an individual namespace for models to enable automatic numbering.

    Each nested ``with ModelNamespace(key):`` pushes a longer id path onto ``ContextStack`` under
    ``key``; ``next_label(key)`` then produces labels such as ``"model_1_2_3"`` by drawing the
    last number from a ``uid`` counter specific to the current path.
    """

    def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
        # for example, key: "model_wrapper"
        self.key = key

    def __enter__(self):
        # For example, with [1, 2, 2] on top of the stack and [1, 2, 2, 3] already used, the path
        # pushed now is [1, 2, 2, 4], and the counter for "model_wrapper_1_2_2_4" restarts from zero.
        try:
            parent = ContextStack.top(self.key)
        except NoContextError:
            # Outermost namespace for this key: start from the empty path.
            child = []
        else:
            child = parent + [uid(self._simple_name(self.key, parent))]
        ContextStack.push(self.key, child)
        reset_uid(self._simple_name(self.key, child))

    def __exit__(self, *args, **kwargs):
        ContextStack.pop(self.key)

    @staticmethod
    def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
        """Return the next automatic label under ``key``, or ``"default_<n>"`` if no namespace is active."""
        try:
            path = ContextStack.top(key)
        except NoContextError:
            # fallback to use "default" namespace
            return ModelNamespace._simple_name('default', [uid()])
        counter_name = ModelNamespace._simple_name(key, path)
        return ModelNamespace._simple_name(key, path + [uid(counter_name)])

    @staticmethod
    def _simple_name(key: str, lst: List[Any]) -> str:
        """Join ``key`` and the items of ``lst`` with underscores, e.g. ("model", [1, 2]) -> "model_1_2"."""
        return '_'.join([key, *map(str, lst)])


155
156
def get_current_context(key: str) -> Any:
    """Return the innermost value pushed onto ``ContextStack`` under ``key``.

    Raises ``NoContextError`` (via ``ContextStack.top``) when nothing is pushed for ``key``.
    """
    innermost = ContextStack.top(key)
    return innermost