launcher_utils.py 6.2 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3
4

import os
chicm-ms's avatar
chicm-ms committed
5
from schema import SchemaError
chicm-ms's avatar
chicm-ms committed
6
7
from .config_schema import NNIConfigSchema
from .common_utils import print_normal
8
9
10
11
12
13
14
15
16

def expand_path(experiment_config, key):
    '''Expand a leading '~' in experiment_config[key] to the user home directory.

    The entry is rewritten in place. Nothing happens when the key is missing
    or its value is falsy (None, empty string).
    '''
    raw_path = experiment_config.get(key)
    if not raw_path:
        return
    experiment_config[key] = os.path.expanduser(raw_path)

def parse_relative_path(root_path, experiment_config, key):
    '''Change relative path to absolute path

    If experiment_config[key] is set and is not already absolute, it is joined
    onto root_path, the rewrite is reported through print_normal, and the entry
    is updated in place. Missing, falsy, or already-absolute entries are left
    untouched.
    '''
    path = experiment_config.get(key)
    if path and not os.path.isabs(path):
        absolute_path = os.path.join(root_path, path)
        print_normal('expand %s: %s to %s ' % (key, path, absolute_path))
        experiment_config[key] = absolute_path

21
22
23
def parse_time(time):
    '''Change the time to seconds

    time is a string made of a non-negative integer followed by a one-letter
    unit: 's' (seconds), 'm' (minutes), 'h' (hours) or 'd' (days),
    e.g. '90m' -> 5400.

    Raises SchemaError when the value is empty, is not a string, has an
    unknown unit, or has a non-integer amount.
    '''
    seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    # Reject empty / non-string input up front so callers always see a
    # SchemaError rather than an IndexError or TypeError from the indexing below.
    if not isinstance(time, str) or not time:
        raise SchemaError('time format error!')
    unit = time[-1]
    if unit not in seconds_per_unit:
        raise SchemaError('the unit of time could only from {s, m, h, d}')
    amount = time[:-1]
    # isdecimal() (not isdigit()) so that everything accepted here is also
    # accepted by int(); isdigit() lets through characters such as superscripts.
    if not amount.isdecimal():
        raise SchemaError('time format error!')
    return int(amount) * seconds_per_unit[unit]
Deshui Yu's avatar
Deshui Yu committed
31

32
33
34
35
36
def _iter_path_fields(experiment_config):
    '''Yield (section, key) for every config entry that may hold a file path.

    section is the dict that owns the entry, so callers can rewrite
    section[key] in place. Sections that are absent or empty are skipped.
    '''
    yield experiment_config, 'searchSpacePath'
    trial = experiment_config.get('trial')
    if trial:
        yield trial, 'codeDir'
        yield trial, 'authFile'
        for role in ('ps', 'master', 'worker'):
            if trial.get(role):
                yield trial[role], 'privateRegistryAuthPath'
        for task_role in trial.get('taskRoles') or []:
            yield task_role, 'privateRegistryAuthPath'
    for component in ('tuner', 'assessor', 'advisor'):
        if experiment_config.get(component):
            yield experiment_config[component], 'codeDir'
    for machine in experiment_config.get('machineList') or []:
        yield machine, 'sshKeyPath'
    # Guarded by the same `trial` check as above: a config without a 'trial'
    # section must not raise KeyError here.
    if trial:
        yield trial, 'paiConfigPath'


def parse_path(experiment_config, config_path):
    '''Parse path in config file

    Every path-valued entry of experiment_config is rewritten in place:
    first '~' is expanded to the user home directory, then any path that is
    still relative is resolved against the directory containing config_path.
    '''
    # Materialise once; only the values are rewritten, not the structure.
    path_fields = list(_iter_path_fields(experiment_config))

    for section, key in path_fields:
        expand_path(section, key)

    #if users use relative path, convert it to absolute path
    root_path = os.path.dirname(config_path)
    for section, key in path_fields:
        parse_relative_path(root_path, section, key)
96

chicm-ms's avatar
chicm-ms committed
97
def set_default_values(experiment_config):
    '''Fill in defaults for optional fields of experiment_config, in place.

    - 'maxExecDuration' defaults to '999d'
    - 'maxTrialNum' defaults to 99999
    - on the 'remote' training service platform, every machineList entry
      without a 'port' gets port 22

    Missing 'trainingServicePlatform' or 'machineList' keys are tolerated
    here rather than raising KeyError.
    '''
    if experiment_config.get('maxExecDuration') is None:
        experiment_config['maxExecDuration'] = '999d'
    if experiment_config.get('maxTrialNum') is None:
        experiment_config['maxTrialNum'] = 99999
    if experiment_config.get('trainingServicePlatform') == 'remote':
        for machine in experiment_config.get('machineList') or []:
            if machine.get('port') is None:
                machine['port'] = 22
Deshui Yu's avatar
Deshui Yu committed
106

107
def validate_all_content(experiment_config, config_path):
    '''Validate whether experiment_config is valid

    experiment_config is modified in place, in this order:
    1. path entries are expanded and made absolute relative to config_path,
    2. default values are filled in for optional fields,
    3. the result is checked with NNIConfigSchema().validate,
    4. 'maxExecDuration' is replaced by its value in seconds.
    '''
    parse_path(experiment_config, config_path)
    set_default_values(experiment_config)

    NNIConfigSchema().validate(experiment_config)

    # Converted only after validation, so the schema sees the original
    # string form (e.g. '999d') rather than an integer number of seconds.
    experiment_config['maxExecDuration'] = parse_time(experiment_config['maxExecDuration'])