util.py 3.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Miscellaneous utility functions.
"""

import math
import os.path
from pathlib import Path
11
from typing import Any, Dict, Optional, Union, List
12
13
14

PathLike = Union[Path, str]

liuzhe-lz's avatar
liuzhe-lz committed
15
16
17
18
19
def case_insensitive(key_or_kwargs: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]:
    """
    Normalize a key, or every key of a dict, for case-insensitive matching.

    A key is normalized by lowercasing it and removing all underscores.
    When given a dict, a new dict is returned with normalized keys and the
    original values; when given a string, the normalized string is returned.
    """
    def _normalize(name: str) -> str:
        return name.lower().replace('_', '')

    if not isinstance(key_or_kwargs, str):
        return {_normalize(name): item for name, item in key_or_kwargs.items()}
    return _normalize(key_or_kwargs)
20
21
22
23
24
25
26
27
28
29
30
31

def camel_case(key: str) -> str:
    """
    Convert an underscore-separated key to camel case.

    The part before the first underscore is kept unchanged; every following
    part is passed through ``str.title()`` and appended without a separator.
    """
    head, *tail = key.split('_')
    return head + ''.join(map(str.title, tail))

def canonical_path(path: Optional[PathLike]) -> Optional[str]:
    """
    Return the absolute form of ``path`` with ``~`` expanded, or None for None.
    """
    if path is None:
        return None
    # Path.resolve() does not work on Windows when file not exist, so use os.path instead
    expanded = os.path.expanduser(path)
    return os.path.abspath(expanded)

def count(*values) -> int:
    """
    Count the arguments that are neither ``None`` nor ``False``.

    The checks are identity checks, so other falsy values such as ``0`` or
    an empty string are still counted.
    """
    total = 0
    for item in values:
        if item is None or item is False:
            continue
        total += 1
    return total

32
def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None): # -> TrainingServiceConfig
    """
    Create training service config object(s) from platform name(s) or from config dict(s).

    Exactly one of ``platform`` and ``config`` is expected to be given:

    - ``platform``: a platform name or a list of names. Each direct subclass
      of ``TrainingServiceConfig`` whose ``platform`` attribute is among the
      names is instantiated with no arguments.
    - ``config``: a dict or a list of dicts. Each dict must contain a
      ``'platform'`` key; the matching subclass is instantiated with the dict
      unpacked as keyword arguments.

    Returns a single config object when exactly one was created, otherwise a
    list of them.

    Raises ``RuntimeError`` when a platform is not recognized, when the same
    platform name is requested more than once, or when nothing was requested
    (empty list).
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import -- confirm before moving it).
    from .common import TrainingServiceConfig

    ts_configs = []
    if platform is not None:
        assert config is None, 'platform and config must not be specified together'
        platforms = platform if isinstance(platform, list) else [platform]
        recognized = []
        for cls in TrainingServiceConfig.__subclasses__():
            if cls.platform in platforms:
                ts_configs.append(cls())
                recognized.append(cls.platform)
        if len(ts_configs) < len(platforms):
            unknown = [name for name in platforms if name not in recognized]
            if unknown:
                # Name the offending platform(s) instead of a generic message.
                raise RuntimeError(f'There is unrecognized platform: {", ".join(map(str, unknown))}')
            # Every name matched a class, so the shortfall is caused by repeated names.
            raise RuntimeError(f'There is duplicated platform in {platforms}')
    else:
        assert config is not None, 'either platform or config must be specified'
        supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
        configs = config if isinstance(config, list) else [config]
        for conf in configs:
            platform_name = conf['platform']
            if platform_name not in supported_platforms:
                raise RuntimeError(f'Unrecognized platform {platform_name}')
            ts_configs.append(supported_platforms[platform_name](**conf))

    if not ts_configs:
        # An empty list of platforms/configs used to surface as a bare IndexError below.
        raise RuntimeError('No training service platform is specified')
    return ts_configs if len(ts_configs) > 1 else ts_configs[0]
52

liuzhe-lz's avatar
liuzhe-lz committed
53
54
55
56
57
58
59
def load_config(Type, value):
    """
    Build ``Type`` instances out of plain data.

    - a dict is turned into ``Type(**value)``;
    - a list is processed element by element (recursively), giving a new list;
    - anything else is returned unchanged.
    """
    if isinstance(value, dict):
        return Type(**value)
    if isinstance(value, list):
        loaded = []
        for element in value:
            loaded.append(load_config(Type, element))
        return loaded
    return value

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def strip_optional(type_hint):
    """
    Unwrap ``Optional[X]`` to ``X``; return any other type hint unchanged.

    A hint counts as ``Optional[X]`` when it is a ``Union`` of exactly two
    members, one of which is ``NoneType``. The hint is inspected through its
    ``__origin__`` and ``__args__`` attributes instead of its string form, so
    the result does not depend on how the running Python version prints
    ``Optional``, and the non-None member is returned whatever its position
    (``Union[None, X]`` also yields ``X``).
    """
    none_type = type(None)
    members = getattr(type_hint, '__args__', None)
    is_union = getattr(type_hint, '__origin__', None) is Union
    if is_union and members is not None and len(members) == 2 and none_type in members:
        first, second = members
        return second if first is none_type else first
    return type_hint

def parse_time(time: str, target_unit: str = 's') -> int:
    """
    Convert a time string such as ``'1.5h'`` to an integer number of ``target_unit``.

    Supported units are ``d``, ``h``, ``m`` and ``s`` (case-insensitive, for
    both the input and ``target_unit``). The result is rounded up.

    Raises ``ValueError`` if the string has no supported unit suffix or its
    numeric part is not a number, and ``KeyError`` if ``target_unit`` is unknown.
    """
    return _parse_unit(time.strip().lower(), target_unit.lower(), _time_units)

def parse_size(size: str, target_unit: str = 'mb') -> int:
    """
    Convert a size string such as ``'1.5gb'`` to an integer number of ``target_unit``.

    Supported units are ``gb``, ``mb`` and ``kb`` (case-insensitive, for both
    the input and ``target_unit``). The result is rounded up.

    Raises ``ValueError`` if the string has no supported unit suffix or its
    numeric part is not a number, and ``KeyError`` if ``target_unit`` is unknown.
    """
    return _parse_unit(size.strip().lower(), target_unit.lower(), _size_units)

# Number of seconds in each supported time unit.
_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
# Number of bytes in each supported size unit.
_size_units = {'gb': 1024 * 1024 * 1024, 'mb': 1024 * 1024, 'kb': 1024}

def _parse_unit(string, target_unit, all_units):
    """
    Parse ``'<number><unit>'`` and convert it to ``target_unit``, rounding up.

    ``all_units`` maps each unit suffix to its factor relative to a common
    base. The first unit (in dict order) that ``string`` ends with is used.
    """
    from decimal import Decimal, InvalidOperation
    from fractions import Fraction

    for unit, factor in all_units.items():
        if string.endswith(unit):
            number = string[:-len(unit)]
            try:
                # Exact rational arithmetic: with binary floats '1.1h' in minutes
                # evaluates to 66.00000000000001 and is rounded up to 67.
                quantity = Fraction(Decimal(number))
            except InvalidOperation:
                raise ValueError(f'Invalid number "{number}" in "{string}"') from None
            return math.ceil(quantity * factor / all_units[target_unit])
    raise ValueError(f'Unsupported unit in "{string}"')