"loaders/create_training_dataset.py" did not exist on "7aa1ab82c66d3bcb93fbe23462e3fb088d29d247"
util.py 2.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Miscellaneous utility functions.
"""

import math
import os.path
from pathlib import Path
liuzhe-lz's avatar
liuzhe-lz committed
11
from typing import Any, Dict, Optional, Union
12
13
14

PathLike = Union[Path, str]

liuzhe-lz's avatar
liuzhe-lz committed
15
16
17
18
19
def case_insensitive(key_or_kwargs: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]:
    """Normalise a key, or every key of a mapping, for case/underscore-insensitive matching.

    A key is normalised by lower-casing it and deleting all underscores, so that
    ``'maxTrialNumber'``, ``'max_trial_number'`` and ``'MAX_TRIAL_NUMBER'`` all become
    ``'maxtrialnumber'``.

    Args:
        key_or_kwargs: a single key, or a dict whose keys should be normalised.

    Returns:
        The normalised string, or a new dict with normalised keys and the original
        values. If two keys normalise to the same string, the later one wins.
    """
    def _normalize(name):
        return name.lower().replace('_', '')

    if not isinstance(key_or_kwargs, str):
        return {_normalize(name): val for name, val in key_or_kwargs.items()}
    return _normalize(key_or_kwargs)
20
21
22
23
24
25
26
27
28
29
30
31

def camel_case(key: str) -> str:
    """Convert a snake_case identifier to camelCase.

    The first underscore-separated segment is kept verbatim; every following segment
    is passed through ``str.title`` (so e.g. ``'max_exec_duration'`` becomes
    ``'maxExecDuration'``). Empty segments from doubled underscores vanish.
    """
    head, *tail = key.split('_')  # str.split always yields at least one element
    return head + ''.join(map(str.title, tail))

def canonical_path(path: Optional[PathLike]) -> Optional[str]:
    """Return *path* as an absolute string path with ``~`` expanded, or None for None.

    The path is made absolute relative to the current working directory and
    normalised by ``os.path.abspath``; symlinks are not resolved and the path need
    not exist.
    """
    if path is None:
        return None
    # Path.resolve() misbehaves on Windows for paths that do not exist yet,
    # so go through os.path rather than pathlib.
    expanded = os.path.expanduser(path)
    return os.path.abspath(expanded)

def count(*values) -> int:
    """Return how many of the given values are "set".

    A value counts unless it is exactly ``None`` or exactly ``False`` (checked by
    identity), so falsy values such as ``0`` or ``''`` are still counted.
    """
    total = 0
    for candidate in values:
        if candidate is None or candidate is False:
            continue
        total += 1
    return total

liuzhe-lz's avatar
liuzhe-lz committed
32
def training_service_config_factory(platform: str, **kwargs): # -> TrainingServiceConfig
    """Instantiate the training service config class that handles *platform*.

    Every subclass of ``TrainingServiceConfig`` (direct or indirect, searched
    breadth-first so direct subclasses are checked first) is inspected for a
    ``platform`` class attribute equal to *platform*. The first match is
    constructed with ``**kwargs`` and returned. Classes without a ``platform``
    attribute are skipped rather than raising.

    Args:
        platform: platform name to look up.
        **kwargs: forwarded verbatim to the matching config class's constructor.

    Raises:
        ValueError: if no subclass declares the requested platform.
    """
    from collections import deque
    # Imported lazily, presumably to avoid a circular import with .common -- TODO confirm.
    from .common import TrainingServiceConfig

    pending = deque(TrainingServiceConfig.__subclasses__())
    visited = set()  # guards against visiting a class twice under multiple inheritance
    while pending:
        cls = pending.popleft()
        if cls in visited:
            continue
        visited.add(cls)
        if getattr(cls, 'platform', None) == platform:
            return cls(**kwargs)
        pending.extend(cls.__subclasses__())
    raise ValueError(f'Unrecognized platform {platform}')

liuzhe-lz's avatar
liuzhe-lz committed
39
40
41
42
43
44
45
def load_config(Type, value):
    """Build config object(s) of class *Type* from raw (e.g. parsed YAML/JSON) data.

    - A list is processed element-wise, recursively, and returned as a new list.
    - A dict is passed to ``Type`` as keyword arguments.
    - Anything else (including ``None`` or an already-built object) is returned as is.
    """
    if isinstance(value, list):
        return list(map(lambda element: load_config(Type, element), value))
    if not isinstance(value, dict):
        return value
    return Type(**value)

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def strip_optional(type_hint):
    """Return ``T`` for a type hint of the form ``Optional[T]``; otherwise return the hint unchanged.

    Recognises ``typing.Optional[T]`` and ``typing.Union[T, None]`` in either argument
    order, plus the PEP 604 spelling ``T | None`` on Python 3.10+. Unions with more
    than one non-None member, plain classes and string annotations are returned as is.

    The check inspects ``__origin__``/``__args__`` instead of ``str(type_hint)``:
    the string form of Optional differs between Python versions (3.7/3.8 print
    ``typing.Union[T, NoneType]``) and on 3.9+ ``Union[None, T]`` prints as
    ``typing.Optional[T]`` while ``__args__[0]`` is ``NoneType``.
    """
    import types  # only needed for the PEP 604 union type (absent before 3.10)

    none_type = type(None)
    is_union = (getattr(type_hint, '__origin__', None) is Union
                or isinstance(type_hint, getattr(types, 'UnionType', ())))
    args = getattr(type_hint, '__args__', ()) if is_union else ()
    if len(args) != 2 or none_type not in args:
        return type_hint
    return args[1] if args[0] is none_type else args[0]

def parse_time(time: str, target_unit: str = 's') -> int:
    """Parse a duration such as ``'2h'`` or ``'1.5d'`` into a whole number of *target_unit*.

    Supported units are ``d``, ``h``, ``m`` and ``s`` (matched case-insensitively in
    *time*); the result is rounded up to the next integer.
    """
    return _parse_unit(string=time.lower(), target_unit=target_unit, all_units=_time_units)

def parse_size(size: str, target_unit: str = 'mb') -> int:
    """Parse a size such as ``'2gb'`` or ``'512kb'`` into a whole number of *target_unit*.

    Supported units are ``gb``, ``mb`` and ``kb`` (binary multiples of 1024, matched
    case-insensitively in *size*); the result is rounded up to the next integer.
    """
    return _parse_unit(string=size.lower(), target_unit=target_unit, all_units=_size_units)

# Seconds per time unit (used by parse_time). Key order matters: _parse_unit tries the suffixes in this order.
_time_units = {'d': 24 * 60 * 60, 'h': 60 * 60, 'm': 60, 's': 1}
# Bytes per size unit, binary multiples (used by parse_size). Key order matters for the same reason.
_size_units = {'gb': 1024 ** 3, 'mb': 1024 ** 2, 'kb': 1024}

def _parse_unit(string, target_unit, all_units):
    """Convert a ``"<number><unit>"`` string into an integer count of *target_unit*.

    Args:
        string: text such as ``'1.5gb'`` or ``' 10m '``; surrounding whitespace is
            ignored and the unit suffix is matched case-insensitively.
        target_unit: unit of the result; must be a key of *all_units* (any case).
        all_units: mapping from lower-case unit suffix to its size in a common base
            unit (e.g. seconds or bytes).

    Returns:
        The value expressed in *target_unit*, rounded up with ``math.ceil``.

    Raises:
        ValueError: if *string* ends with none of the known units, or its numeric
            part is not a valid float.
        KeyError: if *target_unit* is not a known unit.
    """
    text = string.strip().lower()
    divisor = all_units[target_unit.lower()]
    # Try longer suffixes first so that, e.g., 'ms' is not mistaken for 's' with a
    # trailing 'm' in the number. sorted() is stable, so ties keep the dict order.
    for unit in sorted(all_units, key=len, reverse=True):
        if text.endswith(unit):
            number = text[:-len(unit)]
            value = float(number) * all_units[unit]
            return math.ceil(value / divisor)
    raise ValueError(f'Unsupported unit in "{string}"')