"docs/source/Tutorial/Tensorboard.rst" did not exist on "1418a366bc537ce9166b3e18cdbedd84ad406f8c"
Unverified Commit 277e63f2 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Support 3rd-party training service (#3662)

parent e349b440
...@@ -5,11 +5,15 @@ ...@@ -5,11 +5,15 @@
Miscellaneous utility functions. Miscellaneous utility functions.
""" """
import importlib
import json
import math import math
import os.path import os.path
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union, List from typing import Any, Dict, Optional, Union, List
import nni.runtime.config
PathLike = Union[Path, str] PathLike = Union[Path, str]
def case_insensitive(key_or_kwargs: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]: def case_insensitive(key_or_kwargs: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]:
...@@ -34,6 +38,14 @@ def training_service_config_factory( ...@@ -34,6 +38,14 @@ def training_service_config_factory(
config: Union[List, Dict] = None, config: Union[List, Dict] = None,
base_path: Optional[Path] = None): # -> TrainingServiceConfig base_path: Optional[Path] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig from .common import TrainingServiceConfig
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
custom_ts_config_path = nni.runtime.config.get_config_file('training_services.json')
custom_ts_config = json.load(custom_ts_config_path.open())
for custom_ts_pkg in custom_ts_config.keys():
pkg = importlib.import_module(custom_ts_pkg)
_config_class = pkg.nni_training_service_info.config_class
ts_configs = [] ts_configs = []
if platform is not None: if platform is not None:
assert config is None assert config is None
...@@ -42,7 +54,8 @@ def training_service_config_factory( ...@@ -42,7 +54,8 @@ def training_service_config_factory(
if cls.platform in platforms: if cls.platform in platforms:
ts_configs.append(cls()) ts_configs.append(cls())
if len(ts_configs) < len(platforms): if len(ts_configs) < len(platforms):
raise RuntimeError('There is unrecognized platform!') bad = ', '.join(set(platforms) - set(ts_configs))
raise RuntimeError(f'Bad training service platform: {bad}')
else: else:
assert config is not None assert config is not None
supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()} supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
......
...@@ -9,7 +9,5 @@ if trial_env_vars.NNI_PLATFORM is None: ...@@ -9,7 +9,5 @@ if trial_env_vars.NNI_PLATFORM is None:
from .standalone import * from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest': elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import * from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'dlts', 'aml', 'adl', 'hybrid'):
from .local import *
else: else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM) from .local import *
...@@ -16,6 +16,8 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment ...@@ -16,6 +16,8 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment
save_experiment, load_experiment save_experiment, load_experiment
from .algo_management import algo_reg, algo_unreg, algo_show, algo_list from .algo_management import algo_reg, algo_unreg, algo_show, algo_list
from .constants import DEFAULT_REST_PORT from .constants import DEFAULT_REST_PORT
from .import ts_management
init(autoreset=True) init(autoreset=True)
if os.environ.get('COVERAGE_PROCESS_START'): if os.environ.get('COVERAGE_PROCESS_START'):
...@@ -242,6 +244,22 @@ def parse_args(): ...@@ -242,6 +244,22 @@ def parse_args():
parser_algo_list = parser_algo_subparsers.add_parser('list', help='list registered algorithms') parser_algo_list = parser_algo_subparsers.add_parser('list', help='list registered algorithms')
parser_algo_list.set_defaults(func=algo_list) parser_algo_list.set_defaults(func=algo_list)
#parse trainingservice command
parser_ts = subparsers.add_parser('trainingservice', help='control training service')
# add subparsers for parser_ts
parser_ts_subparsers = parser_ts.add_subparsers()
parser_ts_reg = parser_ts_subparsers.add_parser('register', help='register training service')
parser_ts_reg.add_argument('--package', dest='package', help='package name', required=True)
parser_ts_reg.set_defaults(func=ts_management.register)
parser_ts_unreg = parser_ts_subparsers.add_parser('unregister', help='unregister training service')
parser_ts_unreg.add_argument('--package', dest='package', help='package name', required=True)
parser_ts_unreg.set_defaults(func=ts_management.unregister)
parser_ts_list = parser_ts_subparsers.add_parser('list', help='list custom training services')
parser_ts_list.set_defaults(func=ts_management.list_services)
# Show a message that the nnictl package command has been replaced by nnictl algo; to be removed in a future release. # Show a message that the nnictl package command has been replaced by nnictl algo; to be removed in a future release.
def show_messsage_for_nnictl_package(args): def show_messsage_for_nnictl_package(args):
print_error('nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage') print_error('nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage')
......
import importlib
import json
from nni.runtime.config import get_config_file
from .common_utils import print_error, print_green
# Platform names that `register` refuses to accept as a custom training
# service package, because they name training services built into NNI.
_builtin_training_services = [
'local',
'remote',
'openpai', 'pai',
'aml',
'kubeflow',
'frameworkcontroller',
'adl',
]
def register(args):
    """Register (or update) a third-party training service package.

    The package named by ``args.package`` is imported and must expose a
    module-level ``nni_training_service_info`` object providing
    ``config_class``, ``node_module_path`` and ``node_class_name``.
    On success an entry keyed by the package name is stored in
    ``training_services.json``.

    Every validation failure is reported through ``print_error`` and the
    function returns without touching the registry; nothing is raised.

    Parameters
    ----------
    args
        Parsed command line arguments; only ``args.package`` is read.
    """
    if args.package in _builtin_training_services:
        print_error(f'{args.package} is a builtin training service')
        return

    # Importing arbitrary user code can raise anything, so the broad
    # ``except Exception`` clauses below are deliberate: this is the CLI
    # boundary and each failure is turned into a user-facing message.
    try:
        module = importlib.import_module(args.package)
    except Exception:
        print_error(f'Cannot import package {args.package}')
        return

    try:
        info = module.nni_training_service_info
    except Exception:
        print_error(f'Cannot read nni_training_service_info from {args.package}')
        return

    # The config class must be constructible without arguments.
    try:
        info.config_class()
    except Exception:
        print_error('Bad experiment config class')
        return

    try:
        # The keys are camelCase because the Node.js side reads this file
        # and looks up ``nodeModulePath`` / ``nodeClassName``; snake_case
        # keys would be silently ignored there.
        service_config = {
            'nodeModulePath': info.node_module_path,
            'nodeClassName': info.node_class_name,
        }
        # Dry-run serialization so a non-JSON value is rejected before the
        # registry file is rewritten.
        json.dumps(service_config)
    except Exception:
        print_error('Bad node_module_path or bad node_class_name')
        return

    config = _load()
    update = args.package in config
    config[args.package] = service_config
    _save(config)

    if update:
        print_green(f'Sucessfully updated {args.package}')
    else:
        print_green(f'Sucessfully registered {args.package}')
def unregister(args):
    """Remove the entry for ``args.package`` from the training service registry.

    Prints an error and leaves the registry untouched when the package is
    not registered; otherwise rewrites the registry without the entry and
    prints a confirmation.
    """
    registry = _load()
    if args.package in registry:
        del registry[args.package]
        _save(registry)
        print_green(f'Sucessfully unregistered {args.package}')
    else:
        print_error(f'{args.package} is not a registered training service')
def list_services(_):
    """Print the name of every registered custom training service, one per line.

    The single positional argument (the parsed CLI arguments) is ignored.
    """
    registered = _load()
    # Iterating a dict yields its keys, i.e. the registered package names.
    print('\n'.join(registered))
def _load():
    """Return the parsed content of ``training_services.json``.

    The file is opened in a ``with`` block so the handle is closed
    deterministically instead of being left to the garbage collector.
    """
    # presumably get_config_file returns a pathlib.Path-like object; only
    # its ``open()`` method is relied upon here.
    with get_config_file('training_services.json').open() as config_file:
        return json.load(config_file)
def _save(config):
    """Overwrite ``training_services.json`` with ``config`` as indented JSON.

    The file is written inside a ``with`` block so the data is flushed and
    the handle closed before this function returns.
    """
    with get_config_file('training_services.json').open('w') as config_file:
        json.dump(config, config_file, indent=4)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import * as fs from 'fs'; import * as fs from 'fs';
import { Writable } from 'stream'; import { Writable } from 'stream';
import * as util from 'util';
/* log level constants */ /* log level constants */
...@@ -28,7 +29,6 @@ const levelNames = new Map<number, string>([ ...@@ -28,7 +29,6 @@ const levelNames = new Map<number, string>([
/* global_ states */ /* global_ states */
let logFile: Writable | null = null;
let logLevel: number = 0; let logLevel: number = 0;
const loggers = new Map<string, Logger>(); const loggers = new Map<string, Logger>();
...@@ -70,7 +70,8 @@ export class Logger { ...@@ -70,7 +70,8 @@ export class Logger {
} }
private log(level: number, args: any[]): void { private log(level: number, args: any[]): void {
if (level < logLevel || logFile === null) { const logFile: Writable | undefined = (global as any).logFile;
if (level < logLevel || logFile === undefined) {
return; return;
} }
...@@ -80,20 +81,7 @@ export class Logger { ...@@ -80,20 +81,7 @@ export class Logger {
const levelName = levelNames.has(level) ? levelNames.get(level) : level.toString(); const levelName = levelNames.has(level) ? levelNames.get(level) : level.toString();
const words = []; const message = args.map(arg => (typeof arg === 'string' ? arg : util.inspect(arg))).join(' ');
for (const arg of args) {
if (arg === undefined) {
words.push('undefined');
} else if (arg === null) {
words.push('null');
} else if (typeof arg === 'object') {
const json = JSON.stringify(arg);
words.push(json === undefined ? arg : json);
} else {
words.push(arg);
}
}
const message = words.join(' ');
const record = `[${time}] ${levelName} (${this.name}) ${message}\n`; const record = `[${time}] ${levelName} (${this.name}) ${message}\n`;
logFile.write(record); logFile.write(record);
...@@ -124,7 +112,7 @@ export function setLogLevel(levelName: string): void { ...@@ -124,7 +112,7 @@ export function setLogLevel(levelName: string): void {
} }
export function startLogging(logPath: string): void { export function startLogging(logPath: string): void {
logFile = fs.createWriteStream(logPath, { (global as any).logFile = fs.createWriteStream(logPath, {
flags: 'a+', flags: 'a+',
encoding: 'utf8', encoding: 'utf8',
autoClose: true autoClose: true
...@@ -132,8 +120,8 @@ export function startLogging(logPath: string): void { ...@@ -132,8 +120,8 @@ export function startLogging(logPath: string): void {
} }
export function stopLogging(): void { export function stopLogging(): void {
if (logFile !== null) { if ((global as any).logFile !== undefined) {
logFile.end(); (global as any).logFile.end();
logFile = null; (global as any).logFile = undefined;
} }
} }
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import { promisify } from 'util';
import { runPythonScript } from './pythonScript';
/**
 * Description of one custom environment service, as returned by
 * getCustomEnvironmentServiceConfig() from an entry of `training_services.json`.
 */
export interface CustomEnvironmentServiceConfig {
// Key of the entry in the config file (the name the service was looked up by).
name: string;
// Path of the Node.js module that implements the service.
nodeModulePath: string;
// Name of the class to pick from that module's exports.
nodeClassName: string;
}
const readFile = promisify(fs.readFile);
/**
 * Read a file from NNI's config directory and return its content as a string.
 *
 * The directory is obtained by asking the Python side
 * (`nni.runtime.config.get_config_directory()`) through runPythonScript().
 */
async function readConfigFile(fileName: string): Promise<string> {
    const pythonOutput = await runPythonScript(
        'import nni.runtime.config ; print(nni.runtime.config.get_config_directory())'
    );
    // print() appends a newline; strip it before using the value as a path.
    const filePath = path.join(pythonOutput.trim(), fileName);
    const content = await readFile(filePath);
    return content.toString();
}
/**
 * Look up a custom environment service in `training_services.json`.
 *
 * @param name  key of the service in the config file
 * @returns the service description, or `null` when no entry named `name` exists
 *
 * Both camelCase (`nodeModulePath` / `nodeClassName`) and snake_case
 * (`node_module_path` / `node_class_name`) keys are accepted, because the
 * Python CLI that writes this file has used the snake_case spelling.
 */
export async function getCustomEnvironmentServiceConfig(name: string): Promise<CustomEnvironmentServiceConfig | null> {
    const configJson = await readConfigFile('training_services.json');
    const config = JSON.parse(configJson);
    const entry = config[name];
    if (entry === undefined) {
        return null;
    }
    // Prefer the camelCase key, fall back to the snake_case one.
    const modulePath = entry.nodeModulePath !== undefined ? entry.nodeModulePath : entry.node_module_path;
    const className = entry.nodeClassName !== undefined ? entry.nodeClassName : entry.node_class_name;
    return {
        name,
        nodeModulePath: modulePath as string,
        nodeClassName: className as string,
    };
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { spawn } from 'child_process';
import { Logger, getLogger } from './log';
const python = process.platform === 'win32' ? 'python.exe' : 'python3';
/**
 * Run a Python snippet with `python -c` and return everything it wrote to stdout.
 *
 * @param script  Python source passed to the interpreter's `-c` option
 * @param logger  logger used to report stderr output; defaults to getLogger()
 * @returns the script's stdout as a string (empty string when it printed nothing)
 *
 * The returned promise rejects when the process cannot be spawned.
 * Anything written to stderr is logged as a warning and is not treated as an error.
 */
export async function runPythonScript(script: string, logger?: Logger): Promise<string> {
    const proc = spawn(python, [ '-c', script ]);

    // Collect output incrementally. Calling read() once after the process ended
    // returns null when nothing is buffered, which made `.toString()` throw
    // whenever the script wrote nothing to one of the streams.
    let stdout = '';
    let stderr = '';
    proc.stdout.setEncoding('utf8');
    proc.stderr.setEncoding('utf8');
    proc.stdout.on('data', (chunk: string) => { stdout += chunk; });
    proc.stderr.on('data', (chunk: string) => { stderr += chunk; });

    await new Promise<void>((resolve, reject) => {
        proc.on('error', (err: Error) => { reject(err); });
        // 'close' fires only after the stdio streams have ended, so all output
        // has been delivered; 'exit' may fire while data is still pending.
        proc.on('close', () => { resolve(); });
    });

    if (stderr) {
        if (logger === undefined) {
            logger = getLogger();
        }
        logger.warning('python script has stderr.');
        logger.warning('script:', script);
        logger.warning('stderr:', stderr);
    }

    return stdout;
}
...@@ -25,8 +25,7 @@ import { ExperimentManager } from './experimentManager'; ...@@ -25,8 +25,7 @@ import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
function getExperimentRootDir(): string { function getExperimentRootDir(): string {
return getExperimentStartupInfo() return getExperimentStartupInfo().getLogDir();
.getLogDir();
} }
function getLogDir(): string { function getLogDir(): string {
...@@ -34,8 +33,7 @@ function getLogDir(): string { ...@@ -34,8 +33,7 @@ function getLogDir(): string {
} }
function getLogLevel(): string { function getLogLevel(): string {
return getExperimentStartupInfo() return getExperimentStartupInfo().getLogLevel();
.getLogLevel();
} }
function getDefaultDatabaseDir(): string { function getDefaultDatabaseDir(): string {
...@@ -481,6 +479,11 @@ async function getFreePort(host: string, start: number, end: number): Promise<nu ...@@ -481,6 +479,11 @@ async function getFreePort(host: string, start: number, end: number): Promise<nu
} }
} }
/**
 * Load a Node.js module given the path of the module.
 *
 * The parent directory of `modulePath` is prepended to this module's search
 * paths, then the last path component is handed to `require`.
 *
 * NOTE(review): each call prepends another entry to `module.paths` and never
 * removes it, so the list grows with repeated calls - confirm this is acceptable.
 */
export function importModule(modulePath: string): any {
// Must run before require(): makes the containing directory searchable.
module.paths.unshift(path.dirname(modulePath));
return require(path.basename(modulePath));
}
export { export {
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath, countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, getFreePort, isPortOpen, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, getFreePort, isPortOpen,
......
...@@ -445,10 +445,7 @@ class NNIManager implements Manager { ...@@ -445,10 +445,7 @@ class NNIManager implements Manager {
throw new Error('Cannot detect training service platform'); throw new Error('Cannot detect training service platform');
} }
if (['remote', 'pai', 'aml', 'hybrid'].includes(platform)) { if (platform === 'local') {
const module_ = await import('../training_service/reusable/routerTrainingService');
return new module_.RouterTrainingService(config);
} else if (platform === 'local') {
const module_ = await import('../training_service/local/localTrainingService'); const module_ = await import('../training_service/local/localTrainingService');
return new module_.LocalTrainingService(config); return new module_.LocalTrainingService(config);
} else if (platform === 'kubeflow') { } else if (platform === 'kubeflow') {
...@@ -460,6 +457,9 @@ class NNIManager implements Manager { ...@@ -460,6 +457,9 @@ class NNIManager implements Manager {
} else if (platform === 'adl') { } else if (platform === 'adl') {
const module_ = await import('../training_service/kubernetes/adl/adlTrainingService'); const module_ = await import('../training_service/kubernetes/adl/adlTrainingService');
return new module_.AdlTrainingService(); return new module_.AdlTrainingService();
} else {
const module_ = await import('../training_service/reusable/routerTrainingService');
return await module_.RouterTrainingService.construct(config);
} }
throw new Error(`Unsupported training service platform "${platform}"`); throw new Error(`Unsupported training service platform "${platform}"`);
......
...@@ -83,11 +83,6 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals ...@@ -83,11 +83,6 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const port: number = parseInt(strPort, 10); const port: number = parseInt(strPort, 10);
const mode: string = parseArg(['--mode', '-m']); const mode: string = parseArg(['--mode', '-m']);
if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'dlts', 'aml', 'adl', 'hybrid'].includes(mode)) {
console.log(`FATAL: unknown mode: ${mode}`);
usage();
process.exit(1);
}
const startMode: string = parseArg(['--start_mode', '-s']); const startMode: string = parseArg(['--start_mode', '-s']);
if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMode)) { if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMode)) {
......
...@@ -6,9 +6,7 @@ ...@@ -6,9 +6,7 @@
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log'; import { getLogger, Logger } from '../../../common/log';
import { getExperimentRootDir } from '../../../common/utils';
import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig'; import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
import { validateCodeDir } from '../../common/util'; import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient'; import { AMLClient } from '../aml/amlClient';
...@@ -31,10 +29,10 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -31,10 +29,10 @@ export class AMLEnvironmentService extends EnvironmentService {
private experimentRootDir: string; private experimentRootDir: string;
private config: FlattenAmlConfig; private config: FlattenAmlConfig;
constructor(config: ExperimentConfig) { constructor(experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
super(); super();
this.experimentId = getExperimentId(); this.experimentId = experimentId;
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = experimentRootDir;
this.config = flattenConfig(config, 'aml'); this.config = flattenConfig(config, 'aml');
validateCodeDir(this.config.trialCodeDirectory); validateCodeDir(this.config.trialCodeDirectory);
} }
......
...@@ -4,20 +4,31 @@ import { LocalEnvironmentService } from './localEnvironmentService'; ...@@ -4,20 +4,31 @@ import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService'; import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { EnvironmentService } from '../environment'; import { EnvironmentService } from '../environment';
import { ExperimentConfig } from '../../../common/experimentConfig'; import { ExperimentConfig } from '../../../common/experimentConfig';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getCustomEnvironmentServiceConfig } from '../../../common/nniConfig';
import { getExperimentRootDir, importModule } from '../../../common/utils';
export class EnvironmentServiceFactory {
public static createEnvironmentService(name: string, config: ExperimentConfig): EnvironmentService { export async function createEnvironmentService(name: string, config: ExperimentConfig): Promise<EnvironmentService> {
switch(name) { const expId = getExperimentId();
case 'local': const rootDir = getExperimentRootDir();
return new LocalEnvironmentService(config);
case 'remote': switch(name) {
return new RemoteEnvironmentService(config); case 'local':
case 'aml': return new LocalEnvironmentService(rootDir, expId, config);
return new AMLEnvironmentService(config); case 'remote':
case 'openpai': return new RemoteEnvironmentService(rootDir, expId, config);
return new OpenPaiEnvironmentService(config); case 'aml':
default: return new AMLEnvironmentService(rootDir, expId, config);
throw new Error(`${name} not supported!`); case 'openpai':
} return new OpenPaiEnvironmentService(rootDir, expId, config);
}
const esConfig = await getCustomEnvironmentServiceConfig(name);
if (esConfig === null) {
throw new Error(`${name} is not a supported training service!`);
} }
const esModule = importModule(esConfig.nodeModulePath);
const esClass = esModule[esConfig.nodeClassName] as any;
return new esClass(rootDir, expId, config);
} }
...@@ -7,11 +7,10 @@ import * as fs from 'fs'; ...@@ -7,11 +7,10 @@ import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as tkill from 'tree-kill'; import * as tkill from 'tree-kill';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log'; import { getLogger, Logger } from '../../../common/log';
import { ExperimentConfig } from '../../../common/experimentConfig'; import { ExperimentConfig } from '../../../common/experimentConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment';
import { getExperimentRootDir, isAlive, getNewLine } from '../../../common/utils'; import { isAlive, getNewLine } from '../../../common/utils';
import { execMkdir, runScript, getScriptName, execCopydir } from '../../common/util'; import { execMkdir, runScript, getScriptName, execCopydir } from '../../common/util';
import { SharedStorageService } from '../sharedStorage' import { SharedStorageService } from '../sharedStorage'
...@@ -22,10 +21,10 @@ export class LocalEnvironmentService extends EnvironmentService { ...@@ -22,10 +21,10 @@ export class LocalEnvironmentService extends EnvironmentService {
private experimentRootDir: string; private experimentRootDir: string;
private experimentId: string; private experimentId: string;
constructor(_config: ExperimentConfig) { constructor(experimentRootDir: string, experimentId: string, _config: ExperimentConfig) {
super(); super();
this.experimentId = getExperimentId(); this.experimentId = experimentId;
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = experimentRootDir;
} }
public get environmentMaintenceLoopInterval(): number { public get environmentMaintenceLoopInterval(): number {
...@@ -110,8 +109,6 @@ export class LocalEnvironmentService extends EnvironmentService { ...@@ -110,8 +109,6 @@ export class LocalEnvironmentService extends EnvironmentService {
const sharedStorageService = component.get<SharedStorageService>(SharedStorageService); const sharedStorageService = component.get<SharedStorageService>(SharedStorageService);
if (environment.useSharedStorage && sharedStorageService.canLocalMounted) { if (environment.useSharedStorage && sharedStorageService.canLocalMounted) {
this.experimentRootDir = sharedStorageService.localWorkingRoot; this.experimentRootDir = sharedStorageService.localWorkingRoot;
} else {
this.experimentRootDir = getExperimentRootDir();
} }
const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs"); const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs");
if (environment.useSharedStorage && !sharedStorageService.canLocalMounted) { if (environment.useSharedStorage && !sharedStorageService.canLocalMounted) {
......
...@@ -7,7 +7,6 @@ import * as yaml from 'js-yaml'; ...@@ -7,7 +7,6 @@ import * as yaml from 'js-yaml';
import * as request from 'request'; import * as request from 'request';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../../common/experimentConfig'; import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../../common/experimentConfig';
import { getLogger, Logger } from '../../../common/log'; import { getLogger, Logger } from '../../../common/log';
import { PAIClusterConfig } from '../../pai/paiConfig'; import { PAIClusterConfig } from '../../pai/paiConfig';
...@@ -32,9 +31,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService { ...@@ -32,9 +31,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
private experimentId: string; private experimentId: string;
private config: FlattenOpenpaiConfig; private config: FlattenOpenpaiConfig;
constructor(config: ExperimentConfig) { constructor(_experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
super(); super();
this.experimentId = getExperimentId(); this.experimentId = experimentId;
this.config = flattenConfig(config, 'openpai'); this.config = flattenConfig(config, 'openpai');
this.paiToken = this.config.token; this.paiToken = this.config.token;
this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http'; this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http';
......
...@@ -6,10 +6,9 @@ ...@@ -6,10 +6,9 @@
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log'; import { getLogger, Logger } from '../../../common/log';
import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment';
import { getExperimentRootDir, getLogLevel } from '../../../common/utils'; import { getLogLevel } from '../../../common/utils';
import { ExperimentConfig, RemoteConfig, RemoteMachineConfig, flattenConfig } from '../../../common/experimentConfig'; import { ExperimentConfig, RemoteConfig, RemoteMachineConfig, flattenConfig } from '../../../common/experimentConfig';
import { execMkdir } from '../../common/util'; import { execMkdir } from '../../common/util';
import { ExecutorManager } from '../../remote_machine/remoteMachineData'; import { ExecutorManager } from '../../remote_machine/remoteMachineData';
...@@ -33,14 +32,13 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -33,14 +32,13 @@ export class RemoteEnvironmentService extends EnvironmentService {
private experimentId: string; private experimentId: string;
private config: FlattenRemoteConfig; private config: FlattenRemoteConfig;
constructor(config: ExperimentConfig) { constructor(experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
super(); super();
this.experimentId = getExperimentId(); this.experimentId = experimentId;
this.environmentExecutorManagerMap = new Map<string, ExecutorManager>(); this.environmentExecutorManagerMap = new Map<string, ExecutorManager>();
this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>(); this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>();
this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>(); this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>();
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = experimentRootDir;
this.experimentId = getExperimentId();
this.log = getLogger(); this.log = getLogger();
this.config = flattenConfig(config, 'remote'); this.config = flattenConfig(config, 'remote');
...@@ -103,10 +101,10 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -103,10 +101,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
// Create root working directory after executor is ready // Create root working directory after executor is ready
const nniRootDir: string = executor.joinPath(executor.getTempPath(), 'nni-experiments'); const nniRootDir: string = executor.joinPath(executor.getTempPath(), 'nni-experiments');
await executor.createFolder(executor.getRemoteExperimentRootDir(getExperimentId())); await executor.createFolder(executor.getRemoteExperimentRootDir(this.experimentId));
// the directory to store temp scripts in remote machine // the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = executor.getRemoteScriptsPath(getExperimentId()); const remoteGpuScriptCollectorDir: string = executor.getRemoteScriptsPath(this.experimentId);
// clean up previous result. // clean up previous result.
await executor.createFolder(remoteGpuScriptCollectorDir, true); await executor.createFolder(remoteGpuScriptCollectorDir, true);
...@@ -245,7 +243,7 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -245,7 +243,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
throw new Error(`Mount shared storage on remote machine failed.\n ERROR: ${result.stderr}`); throw new Error(`Mount shared storage on remote machine failed.\n ERROR: ${result.stderr}`);
} }
} else { } else {
this.remoteExperimentRootDir = executor.getRemoteExperimentRootDir(getExperimentId()); this.remoteExperimentRootDir = executor.getRemoteExperimentRootDir(this.experimentId);
} }
environment.command = await this.getScript(environment); environment.command = await this.getScript(environment);
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
'use strict'; 'use strict';
import * as component from '../../common/component';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { MethodNotImplementedError } from '../../common/errors'; import { MethodNotImplementedError } from '../../common/errors';
import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig'; import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig';
...@@ -18,23 +17,27 @@ import { TrialDispatcher } from './trialDispatcher'; ...@@ -18,23 +17,27 @@ import { TrialDispatcher } from './trialDispatcher';
* It's an intermediate implementation to support reusable training service. * It's an intermediate implementation to support reusable training service.
* The final goal is to support reusable training job in higher level than training service. * The final goal is to support reusable training job in higher level than training service.
*/ */
@component.Singleton
class RouterTrainingService implements TrainingService { class RouterTrainingService implements TrainingService {
protected readonly log: Logger; private log!: Logger;
private internalTrainingService: TrainingService; private internalTrainingService!: TrainingService;
constructor(config: ExperimentConfig) { public static async construct(config: ExperimentConfig): Promise<RouterTrainingService> {
this.log = getLogger(); const instance = new RouterTrainingService();
instance.log = getLogger('RouterTrainingService');
const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform; const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;
if (platform === 'remote' && !(<RemoteConfig>config.trainingService).reuseMode) { if (platform === 'remote' && !(<RemoteConfig>config.trainingService).reuseMode) {
this.internalTrainingService = new RemoteMachineTrainingService(config); instance.internalTrainingService = new RemoteMachineTrainingService(config);
} else if (platform === 'openpai' && !(<OpenpaiConfig>config.trainingService).reuseMode) { } else if (platform === 'openpai' && !(<OpenpaiConfig>config.trainingService).reuseMode) {
this.internalTrainingService = new PAITrainingService(config); instance.internalTrainingService = new PAITrainingService(config);
} else { } else {
this.internalTrainingService = new TrialDispatcher(config); instance.internalTrainingService = await TrialDispatcher.construct(config);
} }
return instance;
} }
// eslint-disable-next-line @typescript-eslint/no-empty-function
private constructor() { }
public async listTrialJobs(): Promise<TrialJobDetail[]> { public async listTrialJobs(): Promise<TrialJobDetail[]> {
if (this.internalTrainingService === undefined) { if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!"); throw new Error("TrainingService is not assigned!");
......
...@@ -203,7 +203,7 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -203,7 +203,7 @@ describe('Unit Test for TrialDispatcher', () => {
}); });
beforeEach(async () => { beforeEach(async () => {
trialDispatcher = new TrialDispatcher(config); trialDispatcher = await TrialDispatcher.construct(config);
// set ut environment // set ut environment
let environmentServiceList: EnvironmentService[] = []; let environmentServiceList: EnvironmentService[] = [];
......
...@@ -24,7 +24,7 @@ import { TrialConfig } from '../common/trialConfig'; ...@@ -24,7 +24,7 @@ import { TrialConfig } from '../common/trialConfig';
import { validateCodeDir } from '../common/util'; import { validateCodeDir } from '../common/util';
import { Command, CommandChannel } from './commandChannel'; import { Command, CommandChannel } from './commandChannel';
import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment'; import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment';
import { EnvironmentServiceFactory } from './environments/environmentServiceFactory'; import { createEnvironmentService } from './environments/environmentServiceFactory';
import { GpuScheduler } from './gpuScheduler'; import { GpuScheduler } from './gpuScheduler';
import { MountedStorageService } from './storages/mountedStorageService'; import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService'; import { StorageService } from './storageService';
...@@ -39,20 +39,20 @@ import { TrialDetail } from './trial'; ...@@ -39,20 +39,20 @@ import { TrialDetail } from './trial';
**/ **/
@component.Singleton @component.Singleton
class TrialDispatcher implements TrainingService { class TrialDispatcher implements TrainingService {
private readonly log: Logger; private log: Logger;
private readonly isDeveloping: boolean = false; private isDeveloping: boolean = false;
private stopping: boolean = false; private stopping: boolean = false;
private readonly metricsEmitter: EventEmitter; private metricsEmitter: EventEmitter;
private readonly experimentId: string; private experimentId: string;
private readonly experimentRootDir: string; private experimentRootDir: string;
private enableVersionCheck: boolean = true; private enableVersionCheck: boolean = true;
private trialConfig: TrialConfig | undefined; private trialConfig: TrialConfig | undefined;
private readonly trials: Map<string, TrialDetail>; private trials: Map<string, TrialDetail>;
private readonly environments: Map<string, EnvironmentInformation>; private environments: Map<string, EnvironmentInformation>;
// make public for ut // make public for ut
public environmentServiceList: EnvironmentService[] = []; public environmentServiceList: EnvironmentService[] = [];
public commandChannelSet: Set<CommandChannel>; public commandChannelSet: Set<CommandChannel>;
...@@ -82,8 +82,14 @@ class TrialDispatcher implements TrainingService { ...@@ -82,8 +82,14 @@ class TrialDispatcher implements TrainingService {
private config: ExperimentConfig; private config: ExperimentConfig;
constructor(config: ExperimentConfig) { public static async construct(config: ExperimentConfig): Promise<TrialDispatcher> {
this.log = getLogger(); const instance = new TrialDispatcher(config);
await instance.asyncConstructor(config);
return instance;
}
private constructor(config: ExperimentConfig) {
this.log = getLogger('TrialDispatcher');
this.trials = new Map<string, TrialDetail>(); this.trials = new Map<string, TrialDetail>();
this.environments = new Map<string, EnvironmentInformation>(); this.environments = new Map<string, EnvironmentInformation>();
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
...@@ -109,18 +115,14 @@ class TrialDispatcher implements TrainingService { ...@@ -109,18 +115,14 @@ class TrialDispatcher implements TrainingService {
if (this.enableGpuScheduler) { if (this.enableGpuScheduler) {
this.log.info(`TrialDispatcher: GPU scheduler is enabled.`) this.log.info(`TrialDispatcher: GPU scheduler is enabled.`)
} }
}
validateCodeDir(config.trialCodeDirectory); private async asyncConstructor(config: ExperimentConfig): Promise<void> {
await validateCodeDir(config.trialCodeDirectory);
if (Array.isArray(config.trainingService)) { const serviceConfigs = Array.isArray(config.trainingService) ? config.trainingService : [ config.trainingService ];
config.trainingService.forEach(trainingService => { const servicePromises = serviceConfigs.map(serviceConfig => createEnvironmentService(serviceConfig.platform, config));
const env = EnvironmentServiceFactory.createEnvironmentService(trainingService.platform, config); this.environmentServiceList = await Promise.all(servicePromises);
this.environmentServiceList.push(env);
});
} else {
const env = EnvironmentServiceFactory.createEnvironmentService(config.trainingService.platform, config);
this.environmentServiceList.push(env);
}
this.environmentMaintenceLoopInterval = Math.max( this.environmentMaintenceLoopInterval = Math.max(
...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval) ...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval)
...@@ -132,7 +134,7 @@ class TrialDispatcher implements TrainingService { ...@@ -132,7 +134,7 @@ class TrialDispatcher implements TrainingService {
} }
if (this.config.sharedStorage !== undefined) { if (this.config.sharedStorage !== undefined) {
this.initializeSharedStorage(this.config.sharedStorage); await this.initializeSharedStorage(this.config.sharedStorage);
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment