Unverified Commit d5857823 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Config refactor (#4370)

parent cb090e8c
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from nni.experiment.config.base import ConfigBase
# config classes
@dataclass(init=False)
class NestedChild(ConfigBase):
msg: str
int_field: int = 1
def _canonicalize(self, parents):
if '/' not in self.msg:
self.msg = parents[0].msg + '/' + self.msg
super()._canonicalize(parents)
def _validate_canonical(self):
super()._validate_canonical()
if not self.msg.endswith('[2]'):
raise ValueError('not end with [2]')
@dataclass(init=False)
class Child(ConfigBase):
msg: str
children: List[NestedChild]
def _canonicalize(self, parents):
if '/' not in self.msg:
self.msg = parents[0].msg + '/' + self.msg
super()._canonicalize(parents)
def _validate_canonical(self):
super()._validate_canonical()
if not self.msg.endswith('[1]'):
raise ValueError('not end with "[1]"')
@dataclass(init=False)
class TestConfig(ConfigBase):
msg: str
required_field: Optional[int]
optional_field: Optional[int] = None
multi_type_field: Union[int, List[int]]
child: Optional[Child] = None
def _canonicalize(self, parents):
if isinstance(self.multi_type_field, int):
self.multi_type_field = [self.multi_type_field]
super()._canonicalize(parents)
# sample inputs
good = {
'msg': 'a',
'required_field': 10,
'multi_type_field': 20,
'child': {
'msg': 'b[1]',
'children': [{
'msg': 'c[2]',
'int_field': 30,
}, {
'msg': 'd[2]',
}],
},
}
missing = deepcopy(good)
missing.pop('required_field')
wrong_type = deepcopy(good)
wrong_type['optional_field'] = 0.5
nested_wrong_type = deepcopy(good)
nested_wrong_type['child']['children'][1]['int_field'] = 'str'
bad_value = deepcopy(good)
bad_value['child']['msg'] = 'b'
extra_field = deepcopy(good)
extra_field['hello'] = 'world'
bads = {
'missing': missing,
'wrong_type': wrong_type,
'nested_wrong_type': nested_wrong_type,
'bad_value': bad_value,
'extra_field': extra_field,
}
# ground truth
_nested_child_1 = NestedChild()
_nested_child_1.msg = 'c[2]'
_nested_child_1.int_field = 30
_nested_child_2 = NestedChild()
_nested_child_2.msg = 'd[2]'
_nested_child_2.int_field = 1
_child = Child()
_child.msg = 'b[1]'
_child.children = [_nested_child_1, _nested_child_2]
good_config = TestConfig()
good_config.msg = 'a'
good_config.required_field = 10
good_config.optional_field = None
good_config.multi_type_field = 20
good_config.child = _child
_nested_child_1 = NestedChild()
_nested_child_1.msg = 'a/b[1]/c[2]'
_nested_child_1.int_field = 30
_nested_child_2 = NestedChild()
_nested_child_2.msg = 'a/b[1]/d[2]'
_nested_child_2.int_field = 1
_child = Child()
_child.msg = 'a/b[1]'
_child.children = [_nested_child_1, _nested_child_2]
good_canon_config = TestConfig()
good_canon_config.msg = 'a'
good_canon_config.required_field = 10
good_canon_config.optional_field = None
good_canon_config.multi_type_field = [20]
good_canon_config.child = _child
# test function
def test_good():
config = TestConfig(**good)
assert config == good_config
config.validate()
assert config.json() == good_canon_config.json()
def test_bad():
for tag, bad in bads.items():
exc = None
try:
config = TestConfig(**bad)
config.validate()
except Exception as e:
exc = e
assert isinstance(exc, ValueError), tag
if __name__ == '__main__':
test_good()
test_bad()
import os.path
from pathlib import Path
from nni.experiment.config import ExperimentConfig
def expand_path(path):
return os.path.realpath(os.path.join(os.path.dirname(__file__), path))
## minimal config ##
minimal_json = {
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialConcurrency': 2,
'tuner': {
'name': 'random',
},
'trainingService': {
'platform': 'local',
},
}
minimal_class = ExperimentConfig('local')
minimal_class.search_space = {'a': 1}
minimal_class.trial_command = 'python main.py'
minimal_class.trial_concurrency = 2
minimal_class.tuner.name = 'random'
minimal_canon = {
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
'trialConcurrency': 2,
'useAnnotation': False,
'debug': False,
'logLevel': 'info',
'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
'tuner': {'name': 'random'},
'trainingService': {
'platform': 'local',
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
'debug': False,
'maxTrialNumberPerGpu': 1,
'reuseMode': False,
},
}
## detailed config ##
detailed_canon = {
'experimentName': 'test case',
'searchSpaceFile': expand_path('assets/search_space.json'),
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialCodeDirectory': expand_path('assets'),
'trialConcurrency': 2,
'trialGpuNumber': 1,
'maxExperimentDuration': '1.5h',
'maxTrialNumber': 10,
'maxTrialDuration': 60,
'nniManagerIp': '1.2.3.4',
'useAnnotation': False,
'debug': True,
'logLevel': 'warning',
'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
'tunerGpuIndices': [0],
'assessor': {
'name': 'assess',
},
'advisor': {
'className': 'Advisor',
'codeDirectory': expand_path('assets'),
'classArgs': {'random_seed': 0},
},
'trainingService': {
'platform': 'local',
'trialCommand': 'python main.py',
'trialCodeDirectory': expand_path('assets'),
'trialGpuNumber': 1,
'debug': True,
'useActiveGpu': False,
'maxTrialNumberPerGpu': 2,
'gpuIndices': [1, 2],
'reuseMode': True,
},
'sharedStorage': {
'storageType': 'NFS',
'localMountPoint': expand_path('assets'),
'remoteMountPoint': '/tmp',
'localMounted': 'usermount',
'nfsServer': 'nfs.test.case',
'exportedDirectory': 'root',
},
}
## test function ##
def test_all():
minimal = ExperimentConfig(**minimal_json)
assert minimal.json() == minimal_canon
assert minimal_class.json() == minimal_canon
detailed = ExperimentConfig.load(expand_path('assets/config.yaml'))
assert detailed.json() == detailed_canon
if __name__ == '__main__':
test_all()
import os.path
from pathlib import Path
from nni.experiment.config import ExperimentConfig, AlgorithmConfig, RemoteConfig, RemoteMachineConfig
## minimal config ##
minimal_json = {
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialConcurrency': 2,
'tuner': {
'name': 'random',
},
'trainingService': {
'platform': 'remote',
'machine_list': [
{
'host': '1.2.3.4',
'user': 'test_user',
'password': '123456',
},
],
},
}
minimal_class = ExperimentConfig(
search_space = {'a': 1},
trial_command = 'python main.py',
trial_concurrency = 2,
tuner = AlgorithmConfig(
name = 'random',
),
training_service = RemoteConfig(
machine_list = [
RemoteMachineConfig(
host = '1.2.3.4',
user = 'test_user',
password = '123456',
),
],
),
)
minimal_canon = {
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
'trialConcurrency': 2,
'useAnnotation': False,
'debug': False,
'logLevel': 'info',
'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
'tuner': {
'name': 'random',
},
'trainingService': {
'platform': 'remote',
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
'debug': False,
'machineList': [
{
'host': '1.2.3.4',
'port': 22,
'user': 'test_user',
'password': '123456',
'useActiveGpu': False,
'maxTrialNumberPerGpu': 1,
}
],
'reuseMode': True,
}
}
## detailed config ##
detailed_json = {
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialConcurrency': 2,
'trialGpuNumber': 1,
'nni_manager_ip': '1.2.3.0',
'tuner': {
'name': 'random',
},
'trainingService': {
'platform': 'remote',
'machine_list': [
{
'host': '1.2.3.4',
'user': 'test_user',
'password': '123456',
},
{
'host': '1.2.3.5',
'user': 'test_user_2',
'password': 'abcdef',
'use_active_gpu': True,
'max_trial_number_per_gpu': 2,
'gpu_indices': '0,1',
'python_path': '~/path', # don't do this in actual experiment
},
],
},
}
detailed_canon = {
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
'trialConcurrency': 2,
'trialGpuNumber': 1,
'nniManagerIp': '1.2.3.0',
'useAnnotation': False,
'debug': False,
'logLevel': 'info',
'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
'tuner': {'name': 'random'},
'trainingService': {
'platform': 'remote',
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
'trialGpuNumber': 1,
'nniManagerIp': '1.2.3.0',
'debug': False,
'machineList': [
{
'host': '1.2.3.4',
'port': 22,
'user': 'test_user',
'password': '123456',
'useActiveGpu': False,
'maxTrialNumberPerGpu': 1
},
{
'host': '1.2.3.5',
'port': 22,
'user': 'test_user_2',
'password': 'abcdef',
'useActiveGpu': True,
'maxTrialNumberPerGpu': 2,
'gpuIndices': [0, 1],
'pythonPath': '~/path'
}
],
'reuseMode': True,
}
}
## test function ##
def test_remote():
config = ExperimentConfig(**minimal_json)
assert config.json() == minimal_canon
assert minimal_class.json() == minimal_canon
config = ExperimentConfig(**detailed_json)
assert config.json() == detailed_canon
if __name__ == '__main__':
test_remote()
...@@ -8,16 +8,27 @@ import { KubernetesStorageKind } from '../training_service/kubernetes/kubernetes ...@@ -8,16 +8,27 @@ import { KubernetesStorageKind } from '../training_service/kubernetes/kubernetes
export interface TrainingServiceConfig { export interface TrainingServiceConfig {
platform: string; platform: string;
trialCommand: string;
trialCodeDirectory: string;
trialGpuNumber?: number;
nniManagerIp?: string;
// FIXME
// "debug" is only used by openpai to decide whether to check remote nni version
// it should be better to check when local nni version is not "dev"
// it should be even better to check version before launching the experiment and let user to confirm
// log level is currently handled by global logging module and has nothing to do with this
debug?: boolean;
} }
/* Local */ /* Local */
export interface LocalConfig extends TrainingServiceConfig { export interface LocalConfig extends TrainingServiceConfig {
platform: 'local'; platform: 'local';
reuseMode: boolean;
useActiveGpu?: boolean; useActiveGpu?: boolean;
maxTrialNumberPerGpu: number; maxTrialNumberPerGpu: number;
gpuIndices?: number[]; gpuIndices?: number[];
reuseMode: boolean;
} }
/* Remote */ /* Remote */
...@@ -37,8 +48,8 @@ export interface RemoteMachineConfig { ...@@ -37,8 +48,8 @@ export interface RemoteMachineConfig {
export interface RemoteConfig extends TrainingServiceConfig { export interface RemoteConfig extends TrainingServiceConfig {
platform: 'remote'; platform: 'remote';
reuseMode: boolean;
machineList: RemoteMachineConfig[]; machineList: RemoteMachineConfig[];
reuseMode: boolean;
} }
/* OpenPAI */ /* OpenPAI */
...@@ -52,11 +63,11 @@ export interface OpenpaiConfig extends TrainingServiceConfig { ...@@ -52,11 +63,11 @@ export interface OpenpaiConfig extends TrainingServiceConfig {
trialMemorySize: string; trialMemorySize: string;
storageConfigName: string; storageConfigName: string;
dockerImage: string; dockerImage: string;
virtualCluster?: string;
localStorageMountPoint: string; localStorageMountPoint: string;
containerStorageMountPoint: string; containerStorageMountPoint: string;
reuseMode: boolean; reuseMode: boolean;
openpaiConfig?: object; openpaiConfig?: object;
virtualCluster?: string;
} }
/* AML */ /* AML */
...@@ -89,10 +100,8 @@ export interface DlcConfig extends TrainingServiceConfig { ...@@ -89,10 +100,8 @@ export interface DlcConfig extends TrainingServiceConfig {
} }
/* Kubeflow */ /* Kubeflow */
// FIXME: merge with shared storage config
export interface KubernetesStorageConfig { export interface KubernetesStorageConfig {
storageType: string; storageType: string;
maxTrialNumberPerGpu?: number;
server?: string; server?: string;
path?: string; path?: string;
azureAccount?: string; azureAccount?: string;
...@@ -103,51 +112,51 @@ export interface KubernetesStorageConfig { ...@@ -103,51 +112,51 @@ export interface KubernetesStorageConfig {
export interface KubeflowRoleConfig { export interface KubeflowRoleConfig {
replicas: number; replicas: number;
codeDirectory: string;
command: string; command: string;
gpuNumber: number; gpuNumber: number;
cpuNumber: number; cpuNumber: number;
memorySize: number; memorySize: string | number;
dockerImage: string; dockerImage: string;
codeDirectory: string;
privateRegistryAuthPath?: string; privateRegistryAuthPath?: string;
} }
export interface KubeflowConfig extends TrainingServiceConfig { export interface KubeflowConfig extends TrainingServiceConfig {
platform: 'kubeflow'; platform: 'kubeflow';
ps?: KubeflowRoleConfig;
master?: KubeflowRoleConfig;
worker?: KubeflowRoleConfig;
maxTrialNumberPerGpu: number;
operator: KubeflowOperator; operator: KubeflowOperator;
apiVersion: OperatorApiVersion; apiVersion: OperatorApiVersion;
storage: KubernetesStorageConfig; storage: KubernetesStorageConfig;
worker?: KubeflowRoleConfig;
ps?: KubeflowRoleConfig;
master?: KubeflowRoleConfig;
reuseMode: boolean; reuseMode: boolean;
maxTrialNumberPerGpu?: number;
} }
export interface FrameworkControllerTaskRoleConfig { export interface FrameworkControllerTaskRoleConfig {
name: string; name: string;
dockerImage: string;
taskNumber: number; taskNumber: number;
command: string; command: string;
gpuNumber: number; gpuNumber: number;
cpuNumber: number; cpuNumber: number;
memorySize: number; memorySize: string | number;
dockerImage: string;
privateRegistryAuthPath?: string;
frameworkAttemptCompletionPolicy: { frameworkAttemptCompletionPolicy: {
minFailedTaskCount: number; minFailedTaskCount: number;
minSucceedTaskCount: number; minSucceedTaskCount: number;
}; };
privateRegistryAuthPath?: string;
} }
export interface FrameworkControllerConfig extends TrainingServiceConfig { export interface FrameworkControllerConfig extends TrainingServiceConfig {
platform: 'frameworkcontroller'; platform: 'frameworkcontroller';
taskRoles: FrameworkControllerTaskRoleConfig[];
maxTrialNumberPerGpu: number;
storage: KubernetesStorageConfig; storage: KubernetesStorageConfig;
reuseMode: boolean;
namespace: 'default';
apiVersion: string;
serviceAccountName: string; serviceAccountName: string;
taskRoles: FrameworkControllerTaskRoleConfig[];
reuseMode: boolean;
maxTrialNumberPerGpu?: number;
namespace?: 'default';
apiVersion?: string;
} }
/* shared storage */ /* shared storage */
...@@ -182,16 +191,17 @@ export interface AlgorithmConfig { ...@@ -182,16 +191,17 @@ export interface AlgorithmConfig {
export interface ExperimentConfig { export interface ExperimentConfig {
experimentName?: string; experimentName?: string;
// searchSpaceFile (handled in python part)
searchSpace: any; searchSpace: any;
trialCommand: string; trialCommand: string;
trialCodeDirectory: string; trialCodeDirectory: string;
trialConcurrency: number; trialConcurrency: number;
trialGpuNumber?: number; trialGpuNumber?: number;
maxExperimentDuration?: string; maxExperimentDuration?: string | number;
maxTrialDuration?: string;
maxTrialNumber?: number; maxTrialNumber?: number;
maxTrialDuration?: string | number;
nniManagerIp?: string; nniManagerIp?: string;
//useAnnotation: boolean; // dealed inside nnictl // useAnnotation (handled in python part)
debug: boolean; debug: boolean;
logLevel?: string; logLevel?: string;
experimentWorkingDirectory?: string; experimentWorkingDirectory?: string;
...@@ -207,45 +217,31 @@ export interface ExperimentConfig { ...@@ -207,45 +217,31 @@ export interface ExperimentConfig {
/* util functions */ /* util functions */
const timeUnits = { d: 24 * 3600, h: 3600, m: 60, s: 1 }; const timeUnits = { d: 24 * 3600, h: 3600, m: 60, s: 1 };
const sizeUnits = { tb: 1024 ** 4, gb: 1024 ** 3, mb: 1024 ** 2, kb: 1024, b: 1 };
export function toSeconds(time: string): number { function toUnit(value: string | number, targetUnit: string, allUnits: any): number {
for (const [unit, factor] of Object.entries(timeUnits)) { if (typeof value === 'number') {
if (time.toLowerCase().endsWith(unit)) { return value;
const digits = time.slice(0, -1); }
return Number(digits) * factor; value = value.toLowerCase();
for (const [unit, factor] of Object.entries(allUnits)) {
if (value.endsWith(unit)) {
const digits = value.slice(0, -unit.length);
const num = Number(digits) * (factor as number);
return Math.ceil(num / allUnits[targetUnit]);
} }
} }
throw new Error(`Bad time string "${time}"`); throw new Error(`Bad unit in "${value}"`);
} }
const sizeUnits = { tb: 1024 * 1024, gb: 1024, mb: 1, kb: 1 / 1024 }; export function toSeconds(time: string | number): number {
return toUnit(time, 's', timeUnits);
export function toMegaBytes(size: string): number {
for (const [unit, factor] of Object.entries(sizeUnits)) {
if (size.toLowerCase().endsWith(unit)) {
const digits = size.slice(0, -2);
return Math.floor(Number(digits) * factor);
}
}
throw new Error(`Bad size string "${size}"`);
} }
export function toCudaVisibleDevices(gpuIndices?: number[]): string { export function toMegaBytes(size: string | number): number {
return gpuIndices === undefined ? '' : gpuIndices.join(','); return toUnit(size, 'mb', sizeUnits);
} }
export function flattenConfig<T>(config: ExperimentConfig, platform: string): T { export function toCudaVisibleDevices(gpuIndices?: number[]): string {
const flattened = { }; return gpuIndices === undefined ? '' : gpuIndices.join(',');
Object.assign(flattened, config);
if (Array.isArray(config.trainingService)) {
for (const trainingService of config.trainingService) {
if (trainingService.platform === platform) {
Object.assign(flattened, trainingService);
}
}
} else {
assert(config.trainingService.platform === platform);
Object.assign(flattened, config.trainingService);
}
return <T>flattened;
} }
...@@ -13,7 +13,7 @@ import { ...@@ -13,7 +13,7 @@ import {
ExperimentProfile, Manager, ExperimentStatus, ExperimentProfile, Manager, ExperimentStatus,
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager'; } from '../common/manager';
import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig'; import { ExperimentConfig, LocalConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig';
import { ExperimentManager } from '../common/experimentManager'; import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager } from '../common/tensorboardManager'; import { TensorboardManager } from '../common/tensorboardManager';
import { import {
...@@ -454,7 +454,7 @@ class NNIManager implements Manager { ...@@ -454,7 +454,7 @@ class NNIManager implements Manager {
return await module_.RouterTrainingService.construct(config); return await module_.RouterTrainingService.construct(config);
} else if (platform === 'local') { } else if (platform === 'local') {
const module_ = await import('../training_service/local/localTrainingService'); const module_ = await import('../training_service/local/localTrainingService');
return new module_.LocalTrainingService(config); return new module_.LocalTrainingService(<LocalConfig>config.trainingService);
} else if (platform === 'kubeflow') { } else if (platform === 'kubeflow') {
const module_ = await import('../training_service/kubernetes/kubeflow/kubeflowTrainingService'); const module_ = await import('../training_service/kubernetes/kubeflow/kubeflowTrainingService');
return new module_.KubeflowTrainingService(); return new module_.KubeflowTrainingService();
......
...@@ -72,12 +72,12 @@ if (!strPort || strPort.length === 0) { ...@@ -72,12 +72,12 @@ if (!strPort || strPort.length === 0) {
} }
const foregroundArg: string = parseArg(['--foreground', '-f']); const foregroundArg: string = parseArg(['--foreground', '-f']);
if (!('true' || 'false').includes(foregroundArg.toLowerCase())) { if (foregroundArg && !['true', 'false'].includes(foregroundArg.toLowerCase())) {
console.log(`FATAL: foreground property should only be true or false`); console.log(`FATAL: foreground property should only be true or false`);
usage(); usage();
process.exit(1); process.exit(1);
} }
const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : false; const foreground: boolean = (foregroundArg && foregroundArg.toLowerCase() === 'true') ? true : false;
const port: number = parseInt(strPort, 10); const port: number = parseInt(strPort, 10);
...@@ -107,12 +107,12 @@ if (logDir.length > 0) { ...@@ -107,12 +107,12 @@ if (logDir.length > 0) {
const logLevel: string = parseArg(['--log_level', '-ll']); const logLevel: string = parseArg(['--log_level', '-ll']);
const readonlyArg: string = parseArg(['--readonly', '-r']); const readonlyArg: string = parseArg(['--readonly', '-r']);
if (!('true' || 'false').includes(readonlyArg.toLowerCase())) { if (readonlyArg && !['true', 'false'].includes(readonlyArg.toLowerCase())) {
console.log(`FATAL: readonly property should only be true or false`); console.log(`FATAL: readonly property should only be true or false`);
usage(); usage();
process.exit(1); process.exit(1);
} }
const readonly = readonlyArg.toLowerCase() == 'true' ? true : false; const readonly = (readonlyArg && readonlyArg.toLowerCase() == 'true') ? true : false;
const dispatcherPipe: string = parseArg(['--dispatcher_pipe']); const dispatcherPipe: string = parseArg(['--dispatcher_pipe']);
......
...@@ -36,7 +36,7 @@ describe('Unit test for dataStore', () => { ...@@ -36,7 +36,7 @@ describe('Unit test for dataStore', () => {
}); });
it('test experiment profiles CRUD', async () => { it('test experiment profiles CRUD', async () => {
const profile: ExperimentProfile = { const profile: ExperimentProfile = <ExperimentProfile>{
params: { params: {
experimentName: 'exp1', experimentName: 'exp1',
trialConcurrency: 2, trialConcurrency: 2,
......
...@@ -20,7 +20,7 @@ function startProcess(): void { ...@@ -20,7 +20,7 @@ function startProcess(): void {
const dispatcherCmd: string = getMsgDispatcherCommand( const dispatcherCmd: string = getMsgDispatcherCommand(
// Mock tuner config // Mock tuner config
{ <any>{
experimentName: 'exp1', experimentName: 'exp1',
maxExperimentDuration: '1h', maxExperimentDuration: '1h',
searchSpace: '', searchSpace: '',
......
...@@ -40,7 +40,7 @@ describe('Unit test for nnimanager', function () { ...@@ -40,7 +40,7 @@ describe('Unit test for nnimanager', function () {
let ClusterMetadataKey = 'mockedMetadataKey'; let ClusterMetadataKey = 'mockedMetadataKey';
let experimentParams = { let experimentParams: any = {
experimentName: 'naive_experiment', experimentName: 'naive_experiment',
trialConcurrency: 3, trialConcurrency: 3,
maxExperimentDuration: '5s', maxExperimentDuration: '5s',
...@@ -86,7 +86,7 @@ describe('Unit test for nnimanager', function () { ...@@ -86,7 +86,7 @@ describe('Unit test for nnimanager', function () {
debug: true debug: true
} }
let experimentProfile = { let experimentProfile: any = {
params: updateExperimentParams, params: updateExperimentParams,
id: 'test', id: 'test',
execDuration: 0, execDuration: 0,
......
...@@ -14,7 +14,7 @@ import { ExperimentConfig, ExperimentProfile } from '../../common/manager'; ...@@ -14,7 +14,7 @@ import { ExperimentConfig, ExperimentProfile } from '../../common/manager';
import { cleanupUnitTest, getDefaultDatabaseDir, mkDirP, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, getDefaultDatabaseDir, mkDirP, prepareUnitTest } from '../../common/utils';
import { SqlDB } from '../../core/sqlDatabase'; import { SqlDB } from '../../core/sqlDatabase';
const expParams1: ExperimentConfig = { const expParams1: ExperimentConfig = <any>{
experimentName: 'Exp1', experimentName: 'Exp1',
trialConcurrency: 3, trialConcurrency: 3,
maxExperimentDuration: '100s', maxExperimentDuration: '100s',
...@@ -31,7 +31,7 @@ const expParams1: ExperimentConfig = { ...@@ -31,7 +31,7 @@ const expParams1: ExperimentConfig = {
debug: true debug: true
}; };
const expParams2: ExperimentConfig = { const expParams2: ExperimentConfig = <any>{
experimentName: 'Exp2', experimentName: 'Exp2',
trialConcurrency: 5, trialConcurrency: 5,
maxExperimentDuration: '1000s', maxExperimentDuration: '1000s',
......
...@@ -133,7 +133,7 @@ export class MockedNNIManager extends Manager { ...@@ -133,7 +133,7 @@ export class MockedNNIManager extends Manager {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
public getExperimentProfile(): Promise<ExperimentProfile> { public getExperimentProfile(): Promise<ExperimentProfile> {
const profile: ExperimentProfile = { const profile: ExperimentProfile = <any>{
params: { params: {
experimentName: 'exp1', experimentName: 'exp1',
trialConcurrency: 2, trialConcurrency: 2,
......
...@@ -13,7 +13,6 @@ import { TrialJobApplicationForm, TrialJobDetail} from '../../common/trainingSer ...@@ -13,7 +13,6 @@ import { TrialJobApplicationForm, TrialJobDetail} from '../../common/trainingSer
import { cleanupUnitTest, delay, prepareUnitTest, getExperimentRootDir } from '../../common/utils'; import { cleanupUnitTest, delay, prepareUnitTest, getExperimentRootDir } from '../../common/utils';
import { TrialConfigMetadataKey } from '../../training_service/common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../../training_service/common/trialConfigMetadataKey';
import { LocalTrainingService } from '../../training_service/local/localTrainingService'; import { LocalTrainingService } from '../../training_service/local/localTrainingService';
import { ExperimentConfig } from '../../common/experimentConfig';
// TODO: copy mockedTrail.py to local folder // TODO: copy mockedTrail.py to local folder
const localCodeDir: string = tmp.dirSync().name.split('\\').join('\\\\'); const localCodeDir: string = tmp.dirSync().name.split('\\').join('\\\\');
...@@ -21,22 +20,22 @@ const mockedTrialPath: string = './test/mock/mockedTrial.py' ...@@ -21,22 +20,22 @@ const mockedTrialPath: string = './test/mock/mockedTrial.py'
fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py') fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
describe('Unit Test for LocalTrainingService', () => { describe('Unit Test for LocalTrainingService', () => {
const config = <ExperimentConfig>{ const config = <any>{
platform: 'local',
trialCommand: 'sleep 1h && echo hello', trialCommand: 'sleep 1h && echo hello',
trialCodeDirectory: `${localCodeDir}`, trialCodeDirectory: `${localCodeDir}`,
trialGpuNumber: 0, // TODO: add test case for gpu? trialGpuNumber: 0, // TODO: add test case for gpu?
trainingService: { maxTrialNumberPerGpu: 1,
platform: 'local' reuseMode: true,
}
}; };
const config2 = <ExperimentConfig>{ const config2 = <any>{
platform: 'local',
trialCommand: 'python3 mockedTrial.py', trialCommand: 'python3 mockedTrial.py',
trialCodeDirectory: `${localCodeDir}`, trialCodeDirectory: `${localCodeDir}`,
trialGpuNumber: 0, trialGpuNumber: 0,
trainingService: { maxTrialNumberPerGpu: 1,
platform: 'local' reuseMode: true,
}
}; };
before(() => { before(() => {
......
...@@ -169,7 +169,7 @@ async function waitEnvironment(waitCount: number, ...@@ -169,7 +169,7 @@ async function waitEnvironment(waitCount: number,
return waitRequestEnvironment; return waitRequestEnvironment;
} }
const config = { const config: any = {
searchSpace: { }, searchSpace: { },
trialCommand: 'echo hi', trialCommand: 'echo hi',
trialCodeDirectory: path.dirname(__filename), trialCodeDirectory: path.dirname(__filename),
......
...@@ -5,8 +5,7 @@ import assert from 'assert'; ...@@ -5,8 +5,7 @@ import assert from 'assert';
import { import {
AzureStorage, KeyVaultConfig, KubernetesClusterConfig, KubernetesClusterConfigAzure, KubernetesClusterConfigNFS, AzureStorage, KeyVaultConfig, KubernetesClusterConfig, KubernetesClusterConfigAzure, KubernetesClusterConfigNFS,
KubernetesStorageKind, KubernetesTrialConfig, KubernetesTrialConfigTemplate, NFSConfig, StorageConfig, KubernetesClusterConfigPVC, KubernetesStorageKind, KubernetesTrialConfig, KubernetesTrialConfigTemplate, NFSConfig, StorageConfig
PVCConfig,
} from '../kubernetesConfig'; } from '../kubernetesConfig';
export class FrameworkAttemptCompletionPolicy { export class FrameworkAttemptCompletionPolicy {
...@@ -53,31 +52,6 @@ export class FrameworkControllerClusterConfig extends KubernetesClusterConfig { ...@@ -53,31 +52,6 @@ export class FrameworkControllerClusterConfig extends KubernetesClusterConfig {
} }
} }
export class FrameworkControllerClusterConfigPVC extends KubernetesClusterConfigPVC {
public readonly serviceAccountName: string;
public readonly configPath: string;
constructor(serviceAccountName: string, apiVersion: string, pvc: PVCConfig, configPath: string,
storage?: KubernetesStorageKind, namespace?: string) {
super(apiVersion, pvc, storage, namespace);
this.serviceAccountName = serviceAccountName;
this.configPath = configPath
}
public static getInstance(jsonObject: object): FrameworkControllerClusterConfigPVC {
const kubernetesClusterConfigObjectPVC: FrameworkControllerClusterConfigPVC = <FrameworkControllerClusterConfigPVC>jsonObject;
assert(kubernetesClusterConfigObjectPVC !== undefined);
return new FrameworkControllerClusterConfigPVC(
kubernetesClusterConfigObjectPVC.serviceAccountName,
kubernetesClusterConfigObjectPVC.apiVersion,
kubernetesClusterConfigObjectPVC.pvc,
kubernetesClusterConfigObjectPVC.configPath,
kubernetesClusterConfigObjectPVC.storage,
kubernetesClusterConfigObjectPVC.namespace
);
}
}
export class FrameworkControllerClusterConfigNFS extends KubernetesClusterConfigNFS { export class FrameworkControllerClusterConfigNFS extends KubernetesClusterConfigNFS {
public readonly serviceAccountName: string; public readonly serviceAccountName: string;
public readonly configPath?: string; public readonly configPath?: string;
...@@ -153,8 +127,6 @@ export class FrameworkControllerClusterConfigFactory { ...@@ -153,8 +127,6 @@ export class FrameworkControllerClusterConfigFactory {
return FrameworkControllerClusterConfigAzure.getInstance(jsonObject); return FrameworkControllerClusterConfigAzure.getInstance(jsonObject);
} else if (storageConfig.storage === undefined || storageConfig.storage === 'nfs') { } else if (storageConfig.storage === undefined || storageConfig.storage === 'nfs') {
return FrameworkControllerClusterConfigNFS.getInstance(jsonObject); return FrameworkControllerClusterConfigNFS.getInstance(jsonObject);
} else if (storageConfig.storage !== undefined && storageConfig.storage === 'pvc') {
return FrameworkControllerClusterConfigPVC.getInstance(jsonObject);
} }
throw new Error(`Invalid json object ${jsonObject}`); throw new Error(`Invalid json object ${jsonObject}`);
} }
......
...@@ -26,7 +26,6 @@ import { ...@@ -26,7 +26,6 @@ import {
FrameworkControllerClusterConfigNFS, FrameworkControllerClusterConfigNFS,
FrameworkControllerTrialConfig, FrameworkControllerTrialConfig,
FrameworkControllerTrialConfigTemplate, FrameworkControllerTrialConfigTemplate,
FrameworkControllerClusterConfigPVC,
} from './frameworkcontrollerConfig'; } from './frameworkcontrollerConfig';
import {FrameworkControllerJobInfoCollector} from './frameworkcontrollerJobInfoCollector'; import {FrameworkControllerJobInfoCollector} from './frameworkcontrollerJobInfoCollector';
import {FrameworkControllerJobRestServer} from './frameworkcontrollerJobRestServer'; import {FrameworkControllerJobRestServer} from './frameworkcontrollerJobRestServer';
...@@ -239,19 +238,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -239,19 +238,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
nfsFrameworkControllerClusterConfig.nfs.path nfsFrameworkControllerClusterConfig.nfs.path
); );
namespace = nfsFrameworkControllerClusterConfig.namespace namespace = nfsFrameworkControllerClusterConfig.namespace
} else if (this.fcClusterConfig.storageType === 'pvc') {
const pvcFrameworkControllerClusterConfig: FrameworkControllerClusterConfigPVC =
<FrameworkControllerClusterConfigPVC>this.fcClusterConfig;
this.fcTemplate = yaml.safeLoad(
fs.readFileSync(
pvcFrameworkControllerClusterConfig.configPath,
'utf8'
)
);
await this.createPVCStorage(
pvcFrameworkControllerClusterConfig.pvc.path
);
namespace = pvcFrameworkControllerClusterConfig.namespace;
} }
namespace = namespace ? namespace : "default"; namespace = namespace ? namespace : "default";
this.kubernetesCRDClient = FrameworkControllerClientFactory.createClient(namespace); this.kubernetesCRDClient = FrameworkControllerClientFactory.createClient(namespace);
......
...@@ -18,7 +18,7 @@ import { ...@@ -18,7 +18,7 @@ import {
import { import {
delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString
} from 'common/utils'; } from 'common/utils';
import { ExperimentConfig, LocalConfig, flattenConfig } from 'common/experimentConfig'; import { LocalConfig } from 'common/experimentConfig';
import { execMkdir, execNewFile, getScriptName, runScript, setEnvironmentVariable } from '../common/util'; import { execMkdir, execNewFile, getScriptName, runScript, setEnvironmentVariable } from '../common/util';
import { GPUScheduler } from './gpuScheduler'; import { GPUScheduler } from './gpuScheduler';
...@@ -73,13 +73,11 @@ class LocalTrialJobDetail implements TrialJobDetail { ...@@ -73,13 +73,11 @@ class LocalTrialJobDetail implements TrialJobDetail {
} }
} }
interface FlattenLocalConfig extends ExperimentConfig, LocalConfig { }
/** /**
* Local machine training service * Local machine training service
*/ */
class LocalTrainingService implements TrainingService { class LocalTrainingService implements TrainingService {
private readonly config: FlattenLocalConfig; private readonly config: LocalConfig;
private readonly eventEmitter: EventEmitter; private readonly eventEmitter: EventEmitter;
private readonly jobMap: Map<string, LocalTrialJobDetail>; private readonly jobMap: Map<string, LocalTrialJobDetail>;
private readonly jobQueue: string[]; private readonly jobQueue: string[];
...@@ -92,8 +90,8 @@ class LocalTrainingService implements TrainingService { ...@@ -92,8 +90,8 @@ class LocalTrainingService implements TrainingService {
private readonly log: Logger; private readonly log: Logger;
private readonly jobStreamMap: Map<string, ts.Stream>; private readonly jobStreamMap: Map<string, ts.Stream>;
constructor(config: ExperimentConfig) { constructor(config: LocalConfig) {
this.config = flattenConfig<FlattenLocalConfig>(config, 'local'); this.config = config;
this.eventEmitter = new EventEmitter(); this.eventEmitter = new EventEmitter();
this.jobMap = new Map<string, LocalTrialJobDetail>(); this.jobMap = new Map<string, LocalTrialJobDetail>();
this.jobQueue = []; this.jobQueue = [];
......
...@@ -6,11 +6,9 @@ import { Deferred } from 'ts-deferred'; ...@@ -6,11 +6,9 @@ import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from 'common/errors'; import { NNIError, NNIErrorNames } from 'common/errors';
import { getLogger, Logger } from 'common/log'; import { getLogger, Logger } from 'common/log';
import { TrialJobStatus } from 'common/trainingService'; import { TrialJobStatus } from 'common/trainingService';
import { ExperimentConfig, OpenpaiConfig } from 'common/experimentConfig'; import { OpenpaiConfig } from 'common/experimentConfig';
import { PAITrialJobDetail } from './paiConfig'; import { PAITrialJobDetail } from './paiConfig';
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
/** /**
* Collector PAI jobs info from PAI cluster, and update pai job status locally * Collector PAI jobs info from PAI cluster, and update pai job status locally
*/ */
...@@ -26,7 +24,7 @@ export class PAIJobInfoCollector { ...@@ -26,7 +24,7 @@ export class PAIJobInfoCollector {
this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED']; this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
} }
public async retrieveTrialStatus(protocol: string, token? : string, config?: FlattenOpenpaiConfig): Promise<void> { public async retrieveTrialStatus(protocol: string, token? : string, config?: OpenpaiConfig): Promise<void> {
if (config === undefined || token === undefined) { if (config === undefined || token === undefined) {
return Promise.resolve(); return Promise.resolve();
} }
...@@ -42,7 +40,7 @@ export class PAIJobInfoCollector { ...@@ -42,7 +40,7 @@ export class PAIJobInfoCollector {
await Promise.all(updatePaiTrialJobs); await Promise.all(updatePaiTrialJobs);
} }
private getSinglePAITrialJobInfo(_protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, config: FlattenOpenpaiConfig): Promise<void> { private getSinglePAITrialJobInfo(_protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, config: OpenpaiConfig): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) { if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
deferred.resolve(); deferred.resolve();
......
...@@ -16,7 +16,7 @@ import { ...@@ -16,7 +16,7 @@ import {
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from 'common/trainingService'; } from 'common/trainingService';
import { delay } from 'common/utils'; import { delay } from 'common/utils';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from 'common/experimentConfig'; import { OpenpaiConfig, toMegaBytes } from 'common/experimentConfig';
import { PAIJobInfoCollector } from './paiJobInfoCollector'; import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer'; import { PAIJobRestServer } from './paiJobRestServer';
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig'; import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
...@@ -27,8 +27,6 @@ import { execMkdir, validateCodeDir, execCopydir } from '../common/util'; ...@@ -27,8 +27,6 @@ import { execMkdir, validateCodeDir, execCopydir } from '../common/util';
const yaml = require('js-yaml'); const yaml = require('js-yaml');
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
/** /**
* Training Service implementation for OpenPAI (Open Platform for AI) * Training Service implementation for OpenPAI (Open Platform for AI)
* Refer https://github.com/Microsoft/pai for more info about OpenPAI * Refer https://github.com/Microsoft/pai for more info about OpenPAI
...@@ -55,9 +53,9 @@ class PAITrainingService implements TrainingService { ...@@ -55,9 +53,9 @@ class PAITrainingService implements TrainingService {
private copyExpCodeDirPromise?: Promise<void>; private copyExpCodeDirPromise?: Promise<void>;
private paiJobConfig: any; private paiJobConfig: any;
private nniVersion: string | undefined; private nniVersion: string | undefined;
private config: FlattenOpenpaiConfig; private config: OpenpaiConfig;
constructor(config: ExperimentConfig) { constructor(config: OpenpaiConfig) {
this.log = getLogger('PAITrainingService'); this.log = getLogger('PAITrainingService');
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, PAITrialJobDetail>(); this.trialJobsMap = new Map<string, PAITrialJobDetail>();
...@@ -67,7 +65,7 @@ class PAITrainingService implements TrainingService { ...@@ -67,7 +65,7 @@ class PAITrainingService implements TrainingService {
this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap); this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
this.paiTokenUpdateInterval = 7200000; //2hours this.paiTokenUpdateInterval = 7200000; //2hours
this.log.info('Construct paiBase training service.'); this.log.info('Construct paiBase training service.');
this.config = flattenConfig(config, 'openpai'); this.config = config;
this.versionCheck = !this.config.debug; this.versionCheck = !this.config.debug;
this.paiJobRestServer = new PAIJobRestServer(this); this.paiJobRestServer = new PAIJobRestServer(this);
this.paiToken = this.config.token; this.paiToken = this.config.token;
...@@ -221,13 +219,6 @@ class PAITrainingService implements TrainingService { ...@@ -221,13 +219,6 @@ class PAITrainingService implements TrainingService {
protected async statusCheckingLoop(): Promise<void> { protected async statusCheckingLoop(): Promise<void> {
while (!this.stopping) { while (!this.stopping) {
if (this.config.deprecated && this.config.deprecated.password) {
try {
await this.updatePaiToken();
} catch (error) {
this.log.error(`${error}`);
}
}
await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config); await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);
if (this.paiJobRestServer === undefined) { if (this.paiJobRestServer === undefined) {
throw new Error('paiBaseJobRestServer not implemented!'); throw new Error('paiBaseJobRestServer not implemented!');
...@@ -239,55 +230,6 @@ class PAITrainingService implements TrainingService { ...@@ -239,55 +230,6 @@ class PAITrainingService implements TrainingService {
} }
} }
/**
* Update pai token by the interval time or initialize the pai token
*/
protected async updatePaiToken(): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const currentTime: number = new Date().getTime();
//If pai token initialized and not reach the interval time, do not update
if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
return Promise.resolve();
}
const authenticationReq: request.Options = {
uri: `${this.config.host}/rest-server/api/v1/token`,
method: 'POST',
json: true,
body: {
username: this.config.username,
password: this.config.deprecated.password
}
};
request(authenticationReq, (error: Error, response: request.Response, body: any) => {
if (error !== undefined && error !== null) {
this.log.error(`Get PAI token failed: ${error.message}`);
deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
} else {
if (response.statusCode !== 200) {
this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
deferred.reject(new Error(`Get PAI token failed: ${response.body}, please check paiConfig username or password`));
}
this.paiToken = body.token;
this.paiTokenUpdateTime = new Date().getTime();
deferred.resolve();
}
});
let timeoutId: NodeJS.Timer;
const timeoutDelay: Promise<void> = new Promise<void>((_resolve: Function, reject: Function): void => {
// Set timeout and reject the promise once reach timeout (5 seconds)
timeoutId = setTimeout(
() => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
5000);
});
return Promise.race([timeoutDelay, deferred.promise])
.finally(() => { clearTimeout(timeoutId); });
}
public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; } public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
public async getClusterMetadata(_key: string): Promise<string> { return ''; } public async getClusterMetadata(_key: string): Promise<string> { return ''; }
......
...@@ -20,7 +20,7 @@ import { ...@@ -20,7 +20,7 @@ import {
delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus, delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,
getVersion, uniqueString getVersion, uniqueString
} from 'common/utils'; } from 'common/utils';
import { ExperimentConfig, RemoteConfig, RemoteMachineConfig, flattenConfig } from 'common/experimentConfig'; import { RemoteConfig, RemoteMachineConfig } from 'common/experimentConfig';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { GPUSummary, ScheduleResultType } from '../common/gpuData'; import { GPUSummary, ScheduleResultType } from '../common/gpuData';
import { execMkdir, validateCodeDir } from '../common/util'; import { execMkdir, validateCodeDir } from '../common/util';
...@@ -30,8 +30,6 @@ import { ...@@ -30,8 +30,6 @@ import {
} from './remoteMachineData'; } from './remoteMachineData';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer'; import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
interface FlattenRemoteConfig extends ExperimentConfig, RemoteConfig { }
/** /**
* Training Service implementation for Remote Machine (Linux) * Training Service implementation for Remote Machine (Linux)
*/ */
...@@ -53,9 +51,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -53,9 +51,9 @@ class RemoteMachineTrainingService implements TrainingService {
private versionCheck: boolean = true; private versionCheck: boolean = true;
private logCollection: string = 'none'; private logCollection: string = 'none';
private sshConnectionPromises: any[]; private sshConnectionPromises: any[];
private config: FlattenRemoteConfig; private config: RemoteConfig;
constructor(config: ExperimentConfig) { constructor(config: RemoteConfig) {
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>(); this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.trialExecutorManagerMap = new Map<string, ExecutorManager>(); this.trialExecutorManagerMap = new Map<string, ExecutorManager>();
...@@ -67,7 +65,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -67,7 +65,7 @@ class RemoteMachineTrainingService implements TrainingService {
this.timer = component.get(ObservableTimer); this.timer = component.get(ObservableTimer);
this.log = getLogger('RemoteMachineTrainingService'); this.log = getLogger('RemoteMachineTrainingService');
this.log.info('Construct remote machine training service.'); this.log.info('Construct remote machine training service.');
this.config = flattenConfig(config, 'remote'); this.config = config;
if (!fs.lstatSync(this.config.trialCodeDirectory).isDirectory()) { if (!fs.lstatSync(this.config.trialCodeDirectory).isDirectory()) {
throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`); throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment