Unverified Commit d5857823 authored by liuzhe-lz, committed by GitHub

Config refactor (#4370)

parent cb090e8c
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from nni.experiment.config.base import ConfigBase

# config classes

@dataclass(init=False)
class NestedChild(ConfigBase):
    msg: str
    int_field: int = 1

    def _canonicalize(self, parents):
        if '/' not in self.msg:
            self.msg = parents[0].msg + '/' + self.msg
        super()._canonicalize(parents)

    def _validate_canonical(self):
        super()._validate_canonical()
        if not self.msg.endswith('[2]'):
            raise ValueError('not end with [2]')

@dataclass(init=False)
class Child(ConfigBase):
    msg: str
    children: List[NestedChild]

    def _canonicalize(self, parents):
        if '/' not in self.msg:
            self.msg = parents[0].msg + '/' + self.msg
        super()._canonicalize(parents)

    def _validate_canonical(self):
        super()._validate_canonical()
        if not self.msg.endswith('[1]'):
            raise ValueError('not end with "[1]"')

@dataclass(init=False)
class TestConfig(ConfigBase):
    msg: str
    required_field: Optional[int]
    optional_field: Optional[int] = None
    multi_type_field: Union[int, List[int]]
    child: Optional[Child] = None

    def _canonicalize(self, parents):
        if isinstance(self.multi_type_field, int):
            self.multi_type_field = [self.multi_type_field]
        super()._canonicalize(parents)

# sample inputs

good = {
    'msg': 'a',
    'required_field': 10,
    'multi_type_field': 20,
    'child': {
        'msg': 'b[1]',
        'children': [{
            'msg': 'c[2]',
            'int_field': 30,
        }, {
            'msg': 'd[2]',
        }],
    },
}

missing = deepcopy(good)
missing.pop('required_field')

wrong_type = deepcopy(good)
wrong_type['optional_field'] = 0.5

nested_wrong_type = deepcopy(good)
nested_wrong_type['child']['children'][1]['int_field'] = 'str'

bad_value = deepcopy(good)
bad_value['child']['msg'] = 'b'

extra_field = deepcopy(good)
extra_field['hello'] = 'world'

bads = {
    'missing': missing,
    'wrong_type': wrong_type,
    'nested_wrong_type': nested_wrong_type,
    'bad_value': bad_value,
    'extra_field': extra_field,
}

# ground truth

_nested_child_1 = NestedChild()
_nested_child_1.msg = 'c[2]'
_nested_child_1.int_field = 30

_nested_child_2 = NestedChild()
_nested_child_2.msg = 'd[2]'
_nested_child_2.int_field = 1

_child = Child()
_child.msg = 'b[1]'
_child.children = [_nested_child_1, _nested_child_2]

good_config = TestConfig()
good_config.msg = 'a'
good_config.required_field = 10
good_config.optional_field = None
good_config.multi_type_field = 20
good_config.child = _child

_nested_child_1 = NestedChild()
_nested_child_1.msg = 'a/b[1]/c[2]'
_nested_child_1.int_field = 30

_nested_child_2 = NestedChild()
_nested_child_2.msg = 'a/b[1]/d[2]'
_nested_child_2.int_field = 1

_child = Child()
_child.msg = 'a/b[1]'
_child.children = [_nested_child_1, _nested_child_2]

good_canon_config = TestConfig()
good_canon_config.msg = 'a'
good_canon_config.required_field = 10
good_canon_config.optional_field = None
good_canon_config.multi_type_field = [20]
good_canon_config.child = _child

# test function

def test_good():
    config = TestConfig(**good)
    assert config == good_config
    config.validate()
    assert config.json() == good_canon_config.json()

def test_bad():
    for tag, bad in bads.items():
        exc = None
        try:
            config = TestConfig(**bad)
            config.validate()
        except Exception as e:
            exc = e
        assert isinstance(exc, ValueError), tag

if __name__ == '__main__':
    test_good()
    test_bad()
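
# Illustrative sketch (not from the diff above): the same round trip test_good() exercises,
# spelled out in terms of the ground-truth objects built earlier in this file.
def _demo_canonicalization():
    config = TestConfig(**good)              # keyword construction from a plain dict
    assert config == good_config             # equal to the non-canonical ground truth
    config.validate()                        # canonicalizes parent-to-child, then validates
    assert config.json() == good_canon_config.json()
    # canonicalization promotes the scalar field to a list ...
    assert good_canon_config.multi_type_field == [20]
    # ... and prefixes each child's msg with its ancestors' msg
    assert good_canon_config.child.children[0].msg == 'a/b[1]/c[2]'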
import os.path
from pathlib import Path

from nni.experiment.config import ExperimentConfig

def expand_path(path):
    return os.path.realpath(os.path.join(os.path.dirname(__file__), path))

## minimal config ##

minimal_json = {
    'searchSpace': {'a': 1},
    'trialCommand': 'python main.py',
    'trialConcurrency': 2,
    'tuner': {
        'name': 'random',
    },
    'trainingService': {
        'platform': 'local',
    },
}

minimal_class = ExperimentConfig('local')
minimal_class.search_space = {'a': 1}
minimal_class.trial_command = 'python main.py'
minimal_class.trial_concurrency = 2
minimal_class.tuner.name = 'random'

minimal_canon = {
    'searchSpace': {'a': 1},
    'trialCommand': 'python main.py',
    'trialCodeDirectory': os.path.realpath('.'),
    'trialConcurrency': 2,
    'useAnnotation': False,
    'debug': False,
    'logLevel': 'info',
    'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
    'tuner': {'name': 'random'},
    'trainingService': {
        'platform': 'local',
        'trialCommand': 'python main.py',
        'trialCodeDirectory': os.path.realpath('.'),
        'debug': False,
        'maxTrialNumberPerGpu': 1,
        'reuseMode': False,
    },
}

## detailed config ##

detailed_canon = {
    'experimentName': 'test case',
    'searchSpaceFile': expand_path('assets/search_space.json'),
    'searchSpace': {'a': 1},
    'trialCommand': 'python main.py',
    'trialCodeDirectory': expand_path('assets'),
    'trialConcurrency': 2,
    'trialGpuNumber': 1,
    'maxExperimentDuration': '1.5h',
    'maxTrialNumber': 10,
    'maxTrialDuration': 60,
    'nniManagerIp': '1.2.3.4',
    'useAnnotation': False,
    'debug': True,
    'logLevel': 'warning',
    'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
    'tunerGpuIndices': [0],
    'assessor': {
        'name': 'assess',
    },
    'advisor': {
        'className': 'Advisor',
        'codeDirectory': expand_path('assets'),
        'classArgs': {'random_seed': 0},
    },
    'trainingService': {
        'platform': 'local',
        'trialCommand': 'python main.py',
        'trialCodeDirectory': expand_path('assets'),
        'trialGpuNumber': 1,
        'debug': True,
        'useActiveGpu': False,
        'maxTrialNumberPerGpu': 2,
        'gpuIndices': [1, 2],
        'reuseMode': True,
    },
    'sharedStorage': {
        'storageType': 'NFS',
        'localMountPoint': expand_path('assets'),
        'remoteMountPoint': '/tmp',
        'localMounted': 'usermount',
        'nfsServer': 'nfs.test.case',
        'exportedDirectory': 'root',
    },
}

## test function ##

def test_all():
    minimal = ExperimentConfig(**minimal_json)
    assert minimal.json() == minimal_canon
    assert minimal_class.json() == minimal_canon

    detailed = ExperimentConfig.load(expand_path('assets/config.yaml'))
    assert detailed.json() == detailed_canon

if __name__ == '__main__':
    test_all()
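
# Illustrative sketch (not from the diff above): the equivalence test_all() checks --
# keyword construction from a JSON-style dict and attribute assignment on
# ExperimentConfig('local') canonicalize to the same output, minimal_canon, while
# ExperimentConfig.load() does the same for the YAML file checked against detailed_canon.
def _demo_equivalent_constructions():
    assert ExperimentConfig(**minimal_json).json() == minimal_canon
    assert minimal_class.json() == minimal_canon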
import os.path
from pathlib import Path

from nni.experiment.config import ExperimentConfig, AlgorithmConfig, RemoteConfig, RemoteMachineConfig

## minimal config ##

minimal_json = {
    'searchSpace': {'a': 1},
    'trialCommand': 'python main.py',
    'trialConcurrency': 2,
    'tuner': {
        'name': 'random',
    },
    'trainingService': {
        'platform': 'remote',
        'machine_list': [
            {
                'host': '1.2.3.4',
                'user': 'test_user',
                'password': '123456',
            },
        ],
    },
}

minimal_class = ExperimentConfig(
    search_space = {'a': 1},
    trial_command = 'python main.py',
    trial_concurrency = 2,
    tuner = AlgorithmConfig(
        name = 'random',
    ),
    training_service = RemoteConfig(
        machine_list = [
            RemoteMachineConfig(
                host = '1.2.3.4',
                user = 'test_user',
                password = '123456',
            ),
        ],
    ),
)

minimal_canon = {
    'searchSpace': {'a': 1},
    'trialCommand': 'python main.py',
    'trialCodeDirectory': os.path.realpath('.'),
    'trialConcurrency': 2,
    'useAnnotation': False,
    'debug': False,
    'logLevel': 'info',
    'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
    'tuner': {
        'name': 'random',
    },
    'trainingService': {
        'platform': 'remote',
        'trialCommand': 'python main.py',
        'trialCodeDirectory': os.path.realpath('.'),
        'debug': False,
        'machineList': [
            {
                'host': '1.2.3.4',
                'port': 22,
                'user': 'test_user',
                'password': '123456',
                'useActiveGpu': False,
                'maxTrialNumberPerGpu': 1,
            }
        ],
        'reuseMode': True,
    }
}

## detailed config ##

detailed_json = {
    'searchSpace': {'a': 1},
    'trialCommand': 'python main.py',
    'trialConcurrency': 2,
    'trialGpuNumber': 1,
    'nni_manager_ip': '1.2.3.0',
    'tuner': {
        'name': 'random',
    },
    'trainingService': {
        'platform': 'remote',
        'machine_list': [
            {
                'host': '1.2.3.4',
                'user': 'test_user',
                'password': '123456',
            },
            {
                'host': '1.2.3.5',
                'user': 'test_user_2',
                'password': 'abcdef',
                'use_active_gpu': True,
                'max_trial_number_per_gpu': 2,
                'gpu_indices': '0,1',
                'python_path': '~/path',  # don't do this in actual experiment
            },
        ],
    },
}

detailed_canon = {
    'searchSpace': {'a': 1},
    'trialCommand': 'python main.py',
    'trialCodeDirectory': os.path.realpath('.'),
    'trialConcurrency': 2,
    'trialGpuNumber': 1,
    'nniManagerIp': '1.2.3.0',
    'useAnnotation': False,
    'debug': False,
    'logLevel': 'info',
    'experimentWorkingDirectory': str(Path.home() / 'nni-experiments'),
    'tuner': {'name': 'random'},
    'trainingService': {
        'platform': 'remote',
        'trialCommand': 'python main.py',
        'trialCodeDirectory': os.path.realpath('.'),
        'trialGpuNumber': 1,
        'nniManagerIp': '1.2.3.0',
        'debug': False,
        'machineList': [
            {
                'host': '1.2.3.4',
                'port': 22,
                'user': 'test_user',
                'password': '123456',
                'useActiveGpu': False,
                'maxTrialNumberPerGpu': 1
            },
            {
                'host': '1.2.3.5',
                'port': 22,
                'user': 'test_user_2',
                'password': 'abcdef',
                'useActiveGpu': True,
                'maxTrialNumberPerGpu': 2,
                'gpuIndices': [0, 1],
                'pythonPath': '~/path'
            }
        ],
        'reuseMode': True,
    }
}

## test function ##

def test_remote():
    config = ExperimentConfig(**minimal_json)
    assert config.json() == minimal_canon
    assert minimal_class.json() == minimal_canon

    config = ExperimentConfig(**detailed_json)
    assert config.json() == detailed_canon

if __name__ == '__main__':
    test_remote()
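
# Illustrative sketch (not from the diff above): detailed_json mixes camelCase keys
# ('trialGpuNumber') with snake_case keys ('nni_manager_ip', 'machine_list',
# 'use_active_gpu'); the config classes accept both spellings, and json() always emits
# camelCase, as detailed_canon shows.
def _demo_key_styles():
    canon = ExperimentConfig(**detailed_json).json()
    assert canon['nniManagerIp'] == '1.2.3.0'
    assert canon['trainingService']['machineList'][1]['gpuIndices'] == [0, 1]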
......@@ -8,16 +8,27 @@ import { KubernetesStorageKind } from '../training_service/kubernetes/kubernetes
export interface TrainingServiceConfig {
platform: string;
trialCommand: string;
trialCodeDirectory: string;
trialGpuNumber?: number;
nniManagerIp?: string;
// FIXME
// "debug" is only used by openpai to decide whether to check remote nni version
// it should be better to check when local nni version is not "dev"
// it should be even better to check version before launching the experiment and let user to confirm
// log level is currently handled by global logging module and has nothing to do with this
debug?: boolean;
}
/* Local */
export interface LocalConfig extends TrainingServiceConfig {
platform: 'local';
reuseMode: boolean;
useActiveGpu?: boolean;
maxTrialNumberPerGpu: number;
gpuIndices?: number[];
reuseMode: boolean;
}
/* Remote */
......@@ -37,8 +48,8 @@ export interface RemoteMachineConfig {
export interface RemoteConfig extends TrainingServiceConfig {
platform: 'remote';
reuseMode: boolean;
machineList: RemoteMachineConfig[];
reuseMode: boolean;
}
/* OpenPAI */
......@@ -52,11 +63,11 @@ export interface OpenpaiConfig extends TrainingServiceConfig {
trialMemorySize: string;
storageConfigName: string;
dockerImage: string;
virtualCluster?: string;
localStorageMountPoint: string;
containerStorageMountPoint: string;
reuseMode: boolean;
openpaiConfig?: object;
virtualCluster?: string;
}
/* AML */
......@@ -89,10 +100,8 @@ export interface DlcConfig extends TrainingServiceConfig {
}
/* Kubeflow */
// FIXME: merge with shared storage config
export interface KubernetesStorageConfig {
storageType: string;
maxTrialNumberPerGpu?: number;
server?: string;
path?: string;
azureAccount?: string;
......@@ -103,51 +112,51 @@ export interface KubernetesStorageConfig {
export interface KubeflowRoleConfig {
replicas: number;
codeDirectory: string;
command: string;
gpuNumber: number;
cpuNumber: number;
memorySize: number;
memorySize: string | number;
dockerImage: string;
codeDirectory: string;
privateRegistryAuthPath?: string;
}
export interface KubeflowConfig extends TrainingServiceConfig {
platform: 'kubeflow';
ps?: KubeflowRoleConfig;
master?: KubeflowRoleConfig;
worker?: KubeflowRoleConfig;
maxTrialNumberPerGpu: number;
operator: KubeflowOperator;
apiVersion: OperatorApiVersion;
storage: KubernetesStorageConfig;
worker?: KubeflowRoleConfig;
ps?: KubeflowRoleConfig;
master?: KubeflowRoleConfig;
reuseMode: boolean;
maxTrialNumberPerGpu?: number;
}
export interface FrameworkControllerTaskRoleConfig {
name: string;
dockerImage: string;
taskNumber: number;
command: string;
gpuNumber: number;
cpuNumber: number;
memorySize: number;
dockerImage: string;
privateRegistryAuthPath?: string;
memorySize: string | number;
frameworkAttemptCompletionPolicy: {
minFailedTaskCount: number;
minSucceedTaskCount: number;
};
privateRegistryAuthPath?: string;
}
export interface FrameworkControllerConfig extends TrainingServiceConfig {
platform: 'frameworkcontroller';
taskRoles: FrameworkControllerTaskRoleConfig[];
maxTrialNumberPerGpu: number;
storage: KubernetesStorageConfig;
reuseMode: boolean;
namespace: 'default';
apiVersion: string;
serviceAccountName: string;
taskRoles: FrameworkControllerTaskRoleConfig[];
reuseMode: boolean;
maxTrialNumberPerGpu?: number;
namespace?: 'default';
apiVersion?: string;
}
/* shared storage */
......@@ -182,16 +191,17 @@ export interface AlgorithmConfig {
export interface ExperimentConfig {
experimentName?: string;
// searchSpaceFile (handled in python part)
searchSpace: any;
trialCommand: string;
trialCodeDirectory: string;
trialConcurrency: number;
trialGpuNumber?: number;
maxExperimentDuration?: string;
maxTrialDuration?: string;
maxExperimentDuration?: string | number;
maxTrialNumber?: number;
maxTrialDuration?: string | number;
nniManagerIp?: string;
//useAnnotation: boolean; // dealed inside nnictl
// useAnnotation (handled in python part)
debug: boolean;
logLevel?: string;
experimentWorkingDirectory?: string;
......@@ -207,45 +217,31 @@ export interface ExperimentConfig {
/* util functions */
const timeUnits = { d: 24 * 3600, h: 3600, m: 60, s: 1 };
const sizeUnits = { tb: 1024 ** 4, gb: 1024 ** 3, mb: 1024 ** 2, kb: 1024, b: 1 };
export function toSeconds(time: string): number {
for (const [unit, factor] of Object.entries(timeUnits)) {
if (time.toLowerCase().endsWith(unit)) {
const digits = time.slice(0, -1);
return Number(digits) * factor;
function toUnit(value: string | number, targetUnit: string, allUnits: any): number {
if (typeof value === 'number') {
return value;
}
value = value.toLowerCase();
for (const [unit, factor] of Object.entries(allUnits)) {
if (value.endsWith(unit)) {
const digits = value.slice(0, -unit.length);
const num = Number(digits) * (factor as number);
return Math.ceil(num / allUnits[targetUnit]);
}
}
throw new Error(`Bad time string "${time}"`);
throw new Error(`Bad unit in "${value}"`);
}
const sizeUnits = { tb: 1024 * 1024, gb: 1024, mb: 1, kb: 1 / 1024 };
export function toMegaBytes(size: string): number {
for (const [unit, factor] of Object.entries(sizeUnits)) {
if (size.toLowerCase().endsWith(unit)) {
const digits = size.slice(0, -2);
return Math.floor(Number(digits) * factor);
}
}
throw new Error(`Bad size string "${size}"`);
export function toSeconds(time: string | number): number {
return toUnit(time, 's', timeUnits);
}
export function toCudaVisibleDevices(gpuIndices?: number[]): string {
return gpuIndices === undefined ? '' : gpuIndices.join(',');
export function toMegaBytes(size: string | number): number {
return toUnit(size, 'mb', sizeUnits);
}
export function flattenConfig<T>(config: ExperimentConfig, platform: string): T {
const flattened = { };
Object.assign(flattened, config);
if (Array.isArray(config.trainingService)) {
for (const trainingService of config.trainingService) {
if (trainingService.platform === platform) {
Object.assign(flattened, trainingService);
}
}
} else {
assert(config.trainingService.platform === platform);
Object.assign(flattened, config.trainingService);
}
return <T>flattened;
export function toCudaVisibleDevices(gpuIndices?: number[]): string {
return gpuIndices === undefined ? '' : gpuIndices.join(',');
}
......@@ -13,7 +13,7 @@ import {
ExperimentProfile, Manager, ExperimentStatus,
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager';
import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig';
import { ExperimentConfig, LocalConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig';
import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager } from '../common/tensorboardManager';
import {
......@@ -454,7 +454,7 @@ class NNIManager implements Manager {
return await module_.RouterTrainingService.construct(config);
} else if (platform === 'local') {
const module_ = await import('../training_service/local/localTrainingService');
return new module_.LocalTrainingService(config);
return new module_.LocalTrainingService(<LocalConfig>config.trainingService);
} else if (platform === 'kubeflow') {
const module_ = await import('../training_service/kubernetes/kubeflow/kubeflowTrainingService');
return new module_.KubeflowTrainingService();
......
......@@ -72,12 +72,12 @@ if (!strPort || strPort.length === 0) {
}
const foregroundArg: string = parseArg(['--foreground', '-f']);
if (!('true' || 'false').includes(foregroundArg.toLowerCase())) {
if (foregroundArg && !['true', 'false'].includes(foregroundArg.toLowerCase())) {
console.log(`FATAL: foreground property should only be true or false`);
usage();
process.exit(1);
}
const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : false;
const foreground: boolean = (foregroundArg && foregroundArg.toLowerCase() === 'true') ? true : false;
const port: number = parseInt(strPort, 10);
......@@ -107,12 +107,12 @@ if (logDir.length > 0) {
const logLevel: string = parseArg(['--log_level', '-ll']);
const readonlyArg: string = parseArg(['--readonly', '-r']);
if (!('true' || 'false').includes(readonlyArg.toLowerCase())) {
if (readonlyArg && !['true', 'false'].includes(readonlyArg.toLowerCase())) {
console.log(`FATAL: readonly property should only be true or false`);
usage();
process.exit(1);
}
const readonly = readonlyArg.toLowerCase() == 'true' ? true : false;
const readonly = (readonlyArg && readonlyArg.toLowerCase() == 'true') ? true : false;
const dispatcherPipe: string = parseArg(['--dispatcher_pipe']);
......
......@@ -36,7 +36,7 @@ describe('Unit test for dataStore', () => {
});
it('test experiment profiles CRUD', async () => {
const profile: ExperimentProfile = {
const profile: ExperimentProfile = <ExperimentProfile>{
params: {
experimentName: 'exp1',
trialConcurrency: 2,
......
......@@ -20,7 +20,7 @@ function startProcess(): void {
const dispatcherCmd: string = getMsgDispatcherCommand(
// Mock tuner config
{
<any>{
experimentName: 'exp1',
maxExperimentDuration: '1h',
searchSpace: '',
......
......@@ -40,7 +40,7 @@ describe('Unit test for nnimanager', function () {
let ClusterMetadataKey = 'mockedMetadataKey';
let experimentParams = {
let experimentParams: any = {
experimentName: 'naive_experiment',
trialConcurrency: 3,
maxExperimentDuration: '5s',
......@@ -86,7 +86,7 @@ describe('Unit test for nnimanager', function () {
debug: true
}
let experimentProfile = {
let experimentProfile: any = {
params: updateExperimentParams,
id: 'test',
execDuration: 0,
......
......@@ -14,7 +14,7 @@ import { ExperimentConfig, ExperimentProfile } from '../../common/manager';
import { cleanupUnitTest, getDefaultDatabaseDir, mkDirP, prepareUnitTest } from '../../common/utils';
import { SqlDB } from '../../core/sqlDatabase';
const expParams1: ExperimentConfig = {
const expParams1: ExperimentConfig = <any>{
experimentName: 'Exp1',
trialConcurrency: 3,
maxExperimentDuration: '100s',
......@@ -31,7 +31,7 @@ const expParams1: ExperimentConfig = {
debug: true
};
const expParams2: ExperimentConfig = {
const expParams2: ExperimentConfig = <any>{
experimentName: 'Exp2',
trialConcurrency: 5,
maxExperimentDuration: '1000s',
......
......@@ -133,7 +133,7 @@ export class MockedNNIManager extends Manager {
throw new MethodNotImplementedError();
}
public getExperimentProfile(): Promise<ExperimentProfile> {
const profile: ExperimentProfile = {
const profile: ExperimentProfile = <any>{
params: {
experimentName: 'exp1',
trialConcurrency: 2,
......
......@@ -13,7 +13,6 @@ import { TrialJobApplicationForm, TrialJobDetail} from '../../common/trainingSer
import { cleanupUnitTest, delay, prepareUnitTest, getExperimentRootDir } from '../../common/utils';
import { TrialConfigMetadataKey } from '../../training_service/common/trialConfigMetadataKey';
import { LocalTrainingService } from '../../training_service/local/localTrainingService';
import { ExperimentConfig } from '../../common/experimentConfig';
// TODO: copy mockedTrail.py to local folder
const localCodeDir: string = tmp.dirSync().name.split('\\').join('\\\\');
......@@ -21,22 +20,22 @@ const mockedTrialPath: string = './test/mock/mockedTrial.py'
fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
describe('Unit Test for LocalTrainingService', () => {
const config = <ExperimentConfig>{
const config = <any>{
platform: 'local',
trialCommand: 'sleep 1h && echo hello',
trialCodeDirectory: `${localCodeDir}`,
trialGpuNumber: 0, // TODO: add test case for gpu?
trainingService: {
platform: 'local'
}
maxTrialNumberPerGpu: 1,
reuseMode: true,
};
const config2 = <ExperimentConfig>{
const config2 = <any>{
platform: 'local',
trialCommand: 'python3 mockedTrial.py',
trialCodeDirectory: `${localCodeDir}`,
trialGpuNumber: 0,
trainingService: {
platform: 'local'
}
maxTrialNumberPerGpu: 1,
reuseMode: true,
};
before(() => {
......
......@@ -169,7 +169,7 @@ async function waitEnvironment(waitCount: number,
return waitRequestEnvironment;
}
const config = {
const config: any = {
searchSpace: { },
trialCommand: 'echo hi',
trialCodeDirectory: path.dirname(__filename),
......
......@@ -5,8 +5,7 @@ import assert from 'assert';
import {
AzureStorage, KeyVaultConfig, KubernetesClusterConfig, KubernetesClusterConfigAzure, KubernetesClusterConfigNFS,
KubernetesStorageKind, KubernetesTrialConfig, KubernetesTrialConfigTemplate, NFSConfig, StorageConfig, KubernetesClusterConfigPVC,
PVCConfig,
KubernetesStorageKind, KubernetesTrialConfig, KubernetesTrialConfigTemplate, NFSConfig, StorageConfig
} from '../kubernetesConfig';
export class FrameworkAttemptCompletionPolicy {
......@@ -53,31 +52,6 @@ export class FrameworkControllerClusterConfig extends KubernetesClusterConfig {
}
}
export class FrameworkControllerClusterConfigPVC extends KubernetesClusterConfigPVC {
public readonly serviceAccountName: string;
public readonly configPath: string;
constructor(serviceAccountName: string, apiVersion: string, pvc: PVCConfig, configPath: string,
storage?: KubernetesStorageKind, namespace?: string) {
super(apiVersion, pvc, storage, namespace);
this.serviceAccountName = serviceAccountName;
this.configPath = configPath
}
public static getInstance(jsonObject: object): FrameworkControllerClusterConfigPVC {
const kubernetesClusterConfigObjectPVC: FrameworkControllerClusterConfigPVC = <FrameworkControllerClusterConfigPVC>jsonObject;
assert(kubernetesClusterConfigObjectPVC !== undefined);
return new FrameworkControllerClusterConfigPVC(
kubernetesClusterConfigObjectPVC.serviceAccountName,
kubernetesClusterConfigObjectPVC.apiVersion,
kubernetesClusterConfigObjectPVC.pvc,
kubernetesClusterConfigObjectPVC.configPath,
kubernetesClusterConfigObjectPVC.storage,
kubernetesClusterConfigObjectPVC.namespace
);
}
}
export class FrameworkControllerClusterConfigNFS extends KubernetesClusterConfigNFS {
public readonly serviceAccountName: string;
public readonly configPath?: string;
......@@ -153,8 +127,6 @@ export class FrameworkControllerClusterConfigFactory {
return FrameworkControllerClusterConfigAzure.getInstance(jsonObject);
} else if (storageConfig.storage === undefined || storageConfig.storage === 'nfs') {
return FrameworkControllerClusterConfigNFS.getInstance(jsonObject);
} else if (storageConfig.storage !== undefined && storageConfig.storage === 'pvc') {
return FrameworkControllerClusterConfigPVC.getInstance(jsonObject);
}
throw new Error(`Invalid json object ${jsonObject}`);
}
......
......@@ -26,7 +26,6 @@ import {
FrameworkControllerClusterConfigNFS,
FrameworkControllerTrialConfig,
FrameworkControllerTrialConfigTemplate,
FrameworkControllerClusterConfigPVC,
} from './frameworkcontrollerConfig';
import {FrameworkControllerJobInfoCollector} from './frameworkcontrollerJobInfoCollector';
import {FrameworkControllerJobRestServer} from './frameworkcontrollerJobRestServer';
......@@ -239,19 +238,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
nfsFrameworkControllerClusterConfig.nfs.path
);
namespace = nfsFrameworkControllerClusterConfig.namespace
} else if (this.fcClusterConfig.storageType === 'pvc') {
const pvcFrameworkControllerClusterConfig: FrameworkControllerClusterConfigPVC =
<FrameworkControllerClusterConfigPVC>this.fcClusterConfig;
this.fcTemplate = yaml.safeLoad(
fs.readFileSync(
pvcFrameworkControllerClusterConfig.configPath,
'utf8'
)
);
await this.createPVCStorage(
pvcFrameworkControllerClusterConfig.pvc.path
);
namespace = pvcFrameworkControllerClusterConfig.namespace;
}
namespace = namespace ? namespace : "default";
this.kubernetesCRDClient = FrameworkControllerClientFactory.createClient(namespace);
......
......@@ -18,7 +18,7 @@ import {
import {
delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString
} from 'common/utils';
import { ExperimentConfig, LocalConfig, flattenConfig } from 'common/experimentConfig';
import { LocalConfig } from 'common/experimentConfig';
import { execMkdir, execNewFile, getScriptName, runScript, setEnvironmentVariable } from '../common/util';
import { GPUScheduler } from './gpuScheduler';
......@@ -73,13 +73,11 @@ class LocalTrialJobDetail implements TrialJobDetail {
}
}
interface FlattenLocalConfig extends ExperimentConfig, LocalConfig { }
/**
* Local machine training service
*/
class LocalTrainingService implements TrainingService {
private readonly config: FlattenLocalConfig;
private readonly config: LocalConfig;
private readonly eventEmitter: EventEmitter;
private readonly jobMap: Map<string, LocalTrialJobDetail>;
private readonly jobQueue: string[];
......@@ -92,8 +90,8 @@ class LocalTrainingService implements TrainingService {
private readonly log: Logger;
private readonly jobStreamMap: Map<string, ts.Stream>;
constructor(config: ExperimentConfig) {
this.config = flattenConfig<FlattenLocalConfig>(config, 'local');
constructor(config: LocalConfig) {
this.config = config;
this.eventEmitter = new EventEmitter();
this.jobMap = new Map<string, LocalTrialJobDetail>();
this.jobQueue = [];
......
......@@ -6,11 +6,9 @@ import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from 'common/errors';
import { getLogger, Logger } from 'common/log';
import { TrialJobStatus } from 'common/trainingService';
import { ExperimentConfig, OpenpaiConfig } from 'common/experimentConfig';
import { OpenpaiConfig } from 'common/experimentConfig';
import { PAITrialJobDetail } from './paiConfig';
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
/**
* Collector PAI jobs info from PAI cluster, and update pai job status locally
*/
......@@ -26,7 +24,7 @@ export class PAIJobInfoCollector {
this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
}
public async retrieveTrialStatus(protocol: string, token? : string, config?: FlattenOpenpaiConfig): Promise<void> {
public async retrieveTrialStatus(protocol: string, token? : string, config?: OpenpaiConfig): Promise<void> {
if (config === undefined || token === undefined) {
return Promise.resolve();
}
......@@ -42,7 +40,7 @@ export class PAIJobInfoCollector {
await Promise.all(updatePaiTrialJobs);
}
private getSinglePAITrialJobInfo(_protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, config: FlattenOpenpaiConfig): Promise<void> {
private getSinglePAITrialJobInfo(_protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, config: OpenpaiConfig): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
deferred.resolve();
......
......@@ -16,7 +16,7 @@ import {
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from 'common/trainingService';
import { delay } from 'common/utils';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from 'common/experimentConfig';
import { OpenpaiConfig, toMegaBytes } from 'common/experimentConfig';
import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer';
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
......@@ -27,8 +27,6 @@ import { execMkdir, validateCodeDir, execCopydir } from '../common/util';
const yaml = require('js-yaml');
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
/**
* Training Service implementation for OpenPAI (Open Platform for AI)
* Refer https://github.com/Microsoft/pai for more info about OpenPAI
......@@ -55,9 +53,9 @@ class PAITrainingService implements TrainingService {
private copyExpCodeDirPromise?: Promise<void>;
private paiJobConfig: any;
private nniVersion: string | undefined;
private config: FlattenOpenpaiConfig;
private config: OpenpaiConfig;
constructor(config: ExperimentConfig) {
constructor(config: OpenpaiConfig) {
this.log = getLogger('PAITrainingService');
this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, PAITrialJobDetail>();
......@@ -67,7 +65,7 @@ class PAITrainingService implements TrainingService {
this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
this.paiTokenUpdateInterval = 7200000; //2hours
this.log.info('Construct paiBase training service.');
this.config = flattenConfig(config, 'openpai');
this.config = config;
this.versionCheck = !this.config.debug;
this.paiJobRestServer = new PAIJobRestServer(this);
this.paiToken = this.config.token;
......@@ -221,13 +219,6 @@ class PAITrainingService implements TrainingService {
protected async statusCheckingLoop(): Promise<void> {
while (!this.stopping) {
if (this.config.deprecated && this.config.deprecated.password) {
try {
await this.updatePaiToken();
} catch (error) {
this.log.error(`${error}`);
}
}
await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);
if (this.paiJobRestServer === undefined) {
throw new Error('paiBaseJobRestServer not implemented!');
......@@ -239,55 +230,6 @@ class PAITrainingService implements TrainingService {
}
}
/**
* Update pai token by the interval time or initialize the pai token
*/
protected async updatePaiToken(): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const currentTime: number = new Date().getTime();
//If pai token initialized and not reach the interval time, do not update
if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
return Promise.resolve();
}
const authenticationReq: request.Options = {
uri: `${this.config.host}/rest-server/api/v1/token`,
method: 'POST',
json: true,
body: {
username: this.config.username,
password: this.config.deprecated.password
}
};
request(authenticationReq, (error: Error, response: request.Response, body: any) => {
if (error !== undefined && error !== null) {
this.log.error(`Get PAI token failed: ${error.message}`);
deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
} else {
if (response.statusCode !== 200) {
this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
deferred.reject(new Error(`Get PAI token failed: ${response.body}, please check paiConfig username or password`));
}
this.paiToken = body.token;
this.paiTokenUpdateTime = new Date().getTime();
deferred.resolve();
}
});
let timeoutId: NodeJS.Timer;
const timeoutDelay: Promise<void> = new Promise<void>((_resolve: Function, reject: Function): void => {
// Set timeout and reject the promise once reach timeout (5 seconds)
timeoutId = setTimeout(
() => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
5000);
});
return Promise.race([timeoutDelay, deferred.promise])
.finally(() => { clearTimeout(timeoutId); });
}
public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
public async getClusterMetadata(_key: string): Promise<string> { return ''; }
......
......@@ -20,7 +20,7 @@ import {
delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,
getVersion, uniqueString
} from 'common/utils';
import { ExperimentConfig, RemoteConfig, RemoteMachineConfig, flattenConfig } from 'common/experimentConfig';
import { RemoteConfig, RemoteMachineConfig } from 'common/experimentConfig';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { GPUSummary, ScheduleResultType } from '../common/gpuData';
import { execMkdir, validateCodeDir } from '../common/util';
......@@ -30,8 +30,6 @@ import {
} from './remoteMachineData';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
interface FlattenRemoteConfig extends ExperimentConfig, RemoteConfig { }
/**
* Training Service implementation for Remote Machine (Linux)
*/
......@@ -53,9 +51,9 @@ class RemoteMachineTrainingService implements TrainingService {
private versionCheck: boolean = true;
private logCollection: string = 'none';
private sshConnectionPromises: any[];
private config: FlattenRemoteConfig;
private config: RemoteConfig;
constructor(config: ExperimentConfig) {
constructor(config: RemoteConfig) {
this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.trialExecutorManagerMap = new Map<string, ExecutorManager>();
......@@ -67,7 +65,7 @@ class RemoteMachineTrainingService implements TrainingService {
this.timer = component.get(ObservableTimer);
this.log = getLogger('RemoteMachineTrainingService');
this.log.info('Construct remote machine training service.');
this.config = flattenConfig(config, 'remote');
this.config = config;
if (!fs.lstatSync(this.config.trialCodeDirectory).isDirectory()) {
throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`);
......