info.py 8.67 KB
Newer Older
mibaumgartner's avatar
utils  
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
18
import sys
mibaumgartner's avatar
utils  
mibaumgartner committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import copy
import pathlib
import warnings
import functools

from collections.abc import MutableMapping
from subprocess import PIPE, run
from omegaconf.omegaconf import OmegaConf

from tqdm import tqdm
from typing import Mapping, Sequence, Union, Callable, Any, Iterable
from loguru import logger
from contextlib import contextmanager
from typing import Union, Optional
from pathlib import Path
from git import Repo, InvalidGitRepositoryError

36
37
import functools
import inspect
mibaumgartner's avatar
utils  
mibaumgartner committed
38

39
40
41
42
43
44
45
46
47
48
class SuppressPrint:
    """
    Context manager that silences everything written to ``sys.stdout``.

    While the ``with`` block runs, ``sys.stdout`` points to ``os.devnull``.
    On exit the previous stream is restored and the devnull handle opened in
    ``__enter__`` is closed. ``sys.stderr`` is not touched.
    """
    def __enter__(self):
        self._original_stdout = sys.stdout
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first, then close the handle *we* opened. Closing whatever
        # sys.stdout currently is would close a stream installed by the
        # wrapped code and leak our devnull handle.
        sys.stdout = self._original_stdout
        self._devnull.close()


49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
def deprecate(
    replacement: Optional[str] = None,
    deprecate: Optional[str] = None,
    remove: Optional[str] = None,
    ):
    """
    Deprecate functions and classes

    The returned decorator wraps the element so that every call (or, for a
    class, every instantiation) logs a deprecation warning and then forwards
    the call unchanged.

    Args:
        replacement: Optional replacement of old element. If no
            replacement is provided (None) the element is expected to be
            removed completely.
        deprecate: Optional version from which the element is deprecated.
            The message says "now" when omitted.
        remove: Optional version in which the element will be removed.

    Returns:
        Callable: decorator to apply to the deprecated element
    """
    def decorator(func):
        # Use the element's own name for functions AND classes; for a class,
        # ``func.__class__`` is its metaclass (usually ``type``).
        func_name = getattr(func, "__name__", repr(func))
        time_str = "now" if deprecate is None else deprecate

        # The message does not depend on call arguments: build it once.
        s = f"{func_name} is deprecated from {time_str}!"
        if remove is not None:
            s += f" It will be removed from nnDetection from {remove}"
        if replacement is not None:
            s += f" The replacement is {replacement}."
        else:
            s += " There will be no replacement."

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            logger.warning(s)
            return func(*args, **kwargs)
        return wrapper
    return decorator


def experimental(func):
    """
    Mark a function or class as experimental.

    Every call (or, for a class, every instantiation) of the wrapped element
    logs a warning and then forwards the call unchanged.

    Args:
        func: function or class to wrap

    Returns:
        Callable: wrapper that warns before delegating to ``func``
    """
    # Use the element's own qualified name for functions AND classes; for a
    # class, ``func.__class__`` is its metaclass (usually ``type``).
    func_name = getattr(func, "__qualname__", repr(func))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.warning(f"This feature ({func_name}) is experimental! "
                       "It might not implement all features or is only a simplification!")
        return func(*args, **kwargs)
    return wrapper


mibaumgartner's avatar
utils  
mibaumgartner committed
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
def get_requirements():
    """
    Get all installed packages from currently active environment

    ``pip list`` is run through the running interpreter
    (``sys.executable -m pip``), so the listing matches the environment this
    process runs in rather than whichever ``pip`` is first on ``PATH``.

    Returns:
        str: installed packages as printed by ``pip list``

    Raises:
        RuntimeError: if ``pip list`` exits with a non-zero return code
    """
    if sys.executable:
        command = [sys.executable, '-m', 'pip', 'list']
    else:
        # Embedded interpreters may not expose an executable path.
        command = ['pip', 'list']
    result = run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True)
    # pip writes non-fatal notices (e.g. about newer pip releases) to stderr,
    # so success is judged by the exit code, not by an empty stderr.
    if result.returncode != 0:
        raise RuntimeError(
            f"'pip list' failed with exit code {result.returncode}: {result.stderr}")
    return result.stdout


def write_requirements_to_file(path: Union[str, Path]) -> None:
    """
    Write all installed packages from currently active environment to file

    The package list is collected before the file is opened. If collecting
    it fails, an existing file at ``path`` is left untouched instead of being
    truncated.

    Args:
        path: path to file (including file name and extension); an existing
            file is overwritten
    """
    requirements = get_requirements()
    with open(path, "w") as f:
        f.write(requirements)


def get_repo_info(path: Union[str, Path]):
    """
    Parse repository information from path

    Args:
        path: path to repo or to a file/dir inside it. If path is not a
            repository, its parent folders are searched for one. Path
            components that do not exist are skipped.

    Returns:
        dict: contains the current commit ``hash``, ``gitdir`` and
            ``active_branch`` (None when HEAD is detached, e.g. when a
            specific commit is checked out)

    Raises:
        InvalidGitRepositoryError: if no repository is found at or above path
    """
    start = Path(path).absolute()
    for candidate in (start, *start.parents):
        if not candidate.exists():
            # Repo() raises NoSuchPathError (not InvalidGitRepositoryError)
            # for missing paths; keep searching the parents instead.
            continue
        try:
            repo = Repo(candidate)
        except InvalidGitRepositoryError:
            continue
        break
    else:
        raise InvalidGitRepositoryError(
            f"No git repository found at or above {start}")

    # Repo.active_branch raises TypeError on a detached HEAD.
    active_branch = None if repo.head.is_detached else repo.active_branch.name
    return {"hash": repo.head.commit.hexsha,
            "gitdir": repo.git_dir,
            "active_branch": active_branch}


def maybe_verbose_iterable(data: Iterable, **kwargs) -> Iterable:
    """
    Optionally wrap an iterable in a tqdm progress bar.

    The environment variable ``det_verbose`` is parsed as an integer
    (default 1). Any non-zero value enables the progress bar.

    Args:
        data: iterable to wrap
        **kwargs: keyword arguments forwarded to tqdm

    Returns:
        Iterable: ``tqdm(data, **kwargs)`` when verbose, otherwise ``data``
            itself
    """
    verbose = int(os.getenv("det_verbose", 1)) != 0
    return tqdm(data, **kwargs) if verbose else data


def find_name(tdir: Union[str, Path], name: str,
              postfix: Optional[str] = None) -> Path:
    """
    Generate a non-existing name for a file or dir by appending a counter

    Candidates are ``tdir / f"{name}{counter:03d}{postfix}"`` for counters
    0, 1, ..., 1000. The first one that does not exist is returned.
    ``tdir`` (including missing parents) is created if necessary.

    Args:
        tdir: target directory in which the name should be determined
        name: base name
        postfix: postfix appended after name+counter. Defaults to None
            (no postfix).

    Raises:
        RuntimeError: if all counters up to 1000 are already taken

    Returns:
        Path: path to the generated (non-existing) item
    """
    tdir = Path(tdir)
    # exist_ok avoids a race with concurrent creators of the same directory
    tdir.mkdir(parents=True, exist_ok=True)
    if postfix is None:
        postfix = ""

    for counter in range(1001):
        candidate = tdir / f"{name}{counter:03d}{postfix}"
        if not candidate.exists():
            return candidate
    raise RuntimeError(f"Was not able to find name for tdir {tdir} and {name}")


def log_git(repo_path: Union[pathlib.Path, str], repo_name: str = None):
    """
    Read git information of the repository containing ``repo_path``.

    On success nothing is logged; the information from ``get_repo_info`` is
    returned. Any exception while reading is swallowed deliberately
    (best-effort). An error is logged and an empty dict is returned.

    Args:
        repo_path: path to repo or to a file inside it (parent folders are
            searched for the repository)
        repo_name: unused by this function

    Returns:
        dict: git information, or ``{}`` if it could not be read
    """
    try:
        info = get_repo_info(repo_path)
    except Exception:
        logger.error("Was not able to read git information, trying to continue without.")
        info = {}
    return info


def get_cls_name(obj: Any, package_name: bool = True) -> str:
    """
    Get name of class from object

    The dotted class path is built from ``__module__`` and ``__qualname__``
    the same way ``type.__repr__`` does (the ``builtins`` module is omitted).
    The result is either the last component, or the top-level package plus
    the last component.

    Args:
        obj: any object
        package_name: prepend the top-level package the class originates from

    Returns:
        str: name of class, e.g. ``"collections.OrderedDict"`` or
            ``"OrderedDict"``
    """
    # Reading attributes instead of parsing str(cls) also works for classes
    # whose metaclass overrides __repr__.
    cls = obj.__class__
    module = getattr(cls, "__module__", None)
    qualname = getattr(cls, "__qualname__", cls.__name__)
    if not isinstance(module, str) or module == "builtins":
        full_name = qualname
    else:
        full_name = f"{module}.{qualname}"

    parts = full_name.split('.')
    if len(parts) > 1 and package_name:
        return parts[0] + '.' + parts[-1]
    return parts[-1]


def log_error(fn: Callable) -> Callable:
    """
    Log error messages in hydra log when they occur

    The returned wrapper calls ``fn``. If it raises, the exception message is
    logged at error level and the original exception is re-raised with its
    traceback intact.

    Args:
        fn: function to wrap

    Returns:
        Callable: wrapper that keeps ``fn``'s name, docstring and signature
            metadata
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.error(str(e))
            raise
    return wrapper


@contextmanager
def file_logger(path: Union[str, Path], level: str = "DEBUG", overwrite: bool = True):
    """
    Attach a file sink to ``logger`` for the duration of a ``with`` block.

    The sink is removed again when the block exits, even if it raises.

    Args:
        path: output file for the sink
        level: minimum level written to the file. Defaults to "DEBUG".
        overwrite: delete an existing file at ``path`` before the sink is
            added; otherwise the existing file is kept. Defaults to True.

    Yields:
        None
    """
    target = Path(path)
    if overwrite and target.is_file():
        target.unlink()
    sink_id = logger.add(target, level=level)
    try:
        yield
    finally:
        logger.remove(sink_id)


def create_debug_plan(plan: dict) -> dict:
    """
    Create a stringified copy of a plan for debugging output.

    Works on a deep copy, so ``plan`` itself is not modified. The entries
    ``dataset_properties``, ``original_spacings`` and ``original_sizes`` are
    removed if present. The rest is passed through ``stringify_nested_dict``,
    which turns keys and leaf values into strings.

    Args:
        plan: plan dictionary

    Returns:
        dict: stringified copy of the plan without the removed entries
    """
    debug_plan = copy.deepcopy(plan)
    for key in ("dataset_properties", "original_spacings", "original_sizes"):
        debug_plan.pop(key, None)
    return stringify_nested_dict(debug_plan)


def stringify_nested_dict(data: dict):
    """
    Recursively convert a nested structure into strings.

    Dicts stay dicts with ``str`` keys. Lists and tuples become lists. Every
    other value is converted with ``str``.

    Args:
        data: (nested) dict, list, tuple or leaf value

    Returns:
        dict, list or str mirroring the structure of ``data``
    """
    if isinstance(data, (list, tuple)):
        return list(map(stringify_nested_dict, data))
    if not isinstance(data, dict):
        return str(data)
    return {str(k): stringify_nested_dict(v) for k, v in data.items()}


def flatten_mapping(
    nested_mapping: Mapping,
    sep: str = ".",
    ) -> Mapping[str, Any]:
    """
    Flatten nested mutable mappings into a single-level dict.

    Keys are converted with ``str`` and joined with ``sep``. Only values that
    are ``MutableMapping`` instances are descended into. Other values,
    including immutable mappings, are kept as leaves. An empty nested mutable
    mapping contributes no keys.

    Args:
        nested_mapping: mapping to flatten
        sep: separator placed between the keys of different levels

    Returns:
        Mapping[str, Any]: new dict with joined keys in depth-first order
    """
    flat = {}
    for key, value in nested_mapping.items():
        if not isinstance(value, MutableMapping):
            flat[str(key)] = value
            continue
        prefix = str(key) + sep
        flat.update({prefix + str(child_key): child_value
                     for child_key, child_value in flatten_mapping(value, sep=sep).items()})
    return flat