Unverified Commit 277e63f2 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Support 3rd-party training service (#3662)

parent e349b440
......@@ -5,11 +5,15 @@
Miscellaneous utility functions.
"""
import importlib
import json
import math
import os.path
from pathlib import Path
from typing import Any, Dict, Optional, Union, List
import nni.runtime.config
PathLike = Union[Path, str]
def case_insensitive(key_or_kwargs: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]:
......@@ -34,6 +38,14 @@ def training_service_config_factory(
config: Union[List, Dict] = None,
base_path: Optional[Path] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
custom_ts_config_path = nni.runtime.config.get_config_file('training_services.json')
custom_ts_config = json.load(custom_ts_config_path.open())
for custom_ts_pkg in custom_ts_config.keys():
pkg = importlib.import_module(custom_ts_pkg)
_config_class = pkg.nni_training_service_info.config_class
ts_configs = []
if platform is not None:
assert config is None
......@@ -42,7 +54,8 @@ def training_service_config_factory(
if cls.platform in platforms:
ts_configs.append(cls())
if len(ts_configs) < len(platforms):
raise RuntimeError('There is unrecognized platform!')
bad = ', '.join(set(platforms) - set(ts_configs))
raise RuntimeError(f'Bad training service platform: {bad}')
else:
assert config is not None
supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
......
......@@ -9,7 +9,5 @@ if trial_env_vars.NNI_PLATFORM is None:
from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'dlts', 'aml', 'adl', 'hybrid'):
from .local import *
else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
from .local import *
......@@ -16,6 +16,8 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment
save_experiment, load_experiment
from .algo_management import algo_reg, algo_unreg, algo_show, algo_list
from .constants import DEFAULT_REST_PORT
from .import ts_management
init(autoreset=True)
if os.environ.get('COVERAGE_PROCESS_START'):
......@@ -242,6 +244,22 @@ def parse_args():
parser_algo_list = parser_algo_subparsers.add_parser('list', help='list registered algorithms')
parser_algo_list.set_defaults(func=algo_list)
#parse trainingservice command
parser_ts = subparsers.add_parser('trainingservice', help='control training service')
# add subparsers for parser_ts
parser_ts_subparsers = parser_ts.add_subparsers()
parser_ts_reg = parser_ts_subparsers.add_parser('register', help='register training service')
parser_ts_reg.add_argument('--package', dest='package', help='package name', required=True)
parser_ts_reg.set_defaults(func=ts_management.register)
parser_ts_unreg = parser_ts_subparsers.add_parser('unregister', help='unregister training service')
parser_ts_unreg.add_argument('--package', dest='package', help='package name', required=True)
parser_ts_unreg.set_defaults(func=ts_management.unregister)
parser_ts_list = parser_ts_subparsers.add_parser('list', help='list custom training services')
parser_ts_list.set_defaults(func=ts_management.list_services)
# To show message that nnictl package command is replaced by nnictl algo, to be remove in the future release.
def show_messsage_for_nnictl_package(args):
print_error('nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage')
......
import importlib
import json
from nni.runtime.config import get_config_file
from .common_utils import print_error, print_green
_builtin_training_services = [
'local',
'remote',
'openpai', 'pai',
'aml',
'kubeflow',
'frameworkcontroller',
'adl',
]
def register(args):
if args.package in _builtin_training_services:
print_error(f'{args.package} is a builtin training service')
return
try:
module = importlib.import_module(args.package)
except Exception:
print_error(f'Cannot import package {args.package}')
return
try:
info = module.nni_training_service_info
except Exception:
print_error(f'Cannot read nni_training_service_info from {args.package}')
return
try:
info.config_class()
except Exception:
print_error('Bad experiment config class')
return
try:
service_config = {
'node_module_path': info.node_module_path,
'node_class_name': info.node_class_name,
}
json.dumps(service_config)
except Exception:
print_error('Bad node_module_path or bad node_class_name')
return
config = _load()
update = args.package in config
config[args.package] = service_config
_save(config)
if update:
print_green(f'Sucessfully updated {args.package}')
else:
print_green(f'Sucessfully registered {args.package}')
def unregister(args):
config = _load()
if args.package not in config:
print_error(f'{args.package} is not a registered training service')
return
config.pop(args.package, None)
_save(config)
print_green(f'Sucessfully unregistered {args.package}')
def list_services(_):
print('\n'.join(_load().keys()))
def _load():
return json.load(get_config_file('training_services.json').open())
def _save(config):
json.dump(config, get_config_file('training_services.json').open('w'), indent=4)
......@@ -5,6 +5,7 @@
import * as fs from 'fs';
import { Writable } from 'stream';
import * as util from 'util';
/* log level constants */
......@@ -28,7 +29,6 @@ const levelNames = new Map<number, string>([
/* global_ states */
let logFile: Writable | null = null;
let logLevel: number = 0;
const loggers = new Map<string, Logger>();
......@@ -70,7 +70,8 @@ export class Logger {
}
private log(level: number, args: any[]): void {
if (level < logLevel || logFile === null) {
const logFile: Writable | undefined = (global as any).logFile;
if (level < logLevel || logFile === undefined) {
return;
}
......@@ -80,20 +81,7 @@ export class Logger {
const levelName = levelNames.has(level) ? levelNames.get(level) : level.toString();
const words = [];
for (const arg of args) {
if (arg === undefined) {
words.push('undefined');
} else if (arg === null) {
words.push('null');
} else if (typeof arg === 'object') {
const json = JSON.stringify(arg);
words.push(json === undefined ? arg : json);
} else {
words.push(arg);
}
}
const message = words.join(' ');
const message = args.map(arg => (typeof arg === 'string' ? arg : util.inspect(arg))).join(' ');
const record = `[${time}] ${levelName} (${this.name}) ${message}\n`;
logFile.write(record);
......@@ -124,7 +112,7 @@ export function setLogLevel(levelName: string): void {
}
export function startLogging(logPath: string): void {
logFile = fs.createWriteStream(logPath, {
(global as any).logFile = fs.createWriteStream(logPath, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
......@@ -132,8 +120,8 @@ export function startLogging(logPath: string): void {
}
export function stopLogging(): void {
if (logFile !== null) {
logFile.end();
logFile = null;
if ((global as any).logFile !== undefined) {
(global as any).logFile.end();
(global as any).logFile = undefined;
}
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import { promisify } from 'util';
import { runPythonScript } from './pythonScript';
export interface CustomEnvironmentServiceConfig {
name: string;
nodeModulePath: string;
nodeClassName: string;
}
const readFile = promisify(fs.readFile);
async function readConfigFile(fileName: string): Promise<string> {
const script = 'import nni.runtime.config ; print(nni.runtime.config.get_config_directory())';
const configDir = (await runPythonScript(script)).trim();
const stream = await readFile(path.join(configDir, fileName));
return stream.toString();
}
export async function getCustomEnvironmentServiceConfig(name: string): Promise<CustomEnvironmentServiceConfig | null> {
const configJson = await readConfigFile('training_services.json');
const config = JSON.parse(configJson);
if (config[name] === undefined) {
return null;
}
return {
name,
nodeModulePath: config[name].nodeModulePath as string,
nodeClassName: config[name].nodeClassName as string,
}
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { spawn } from 'child_process';
import { Logger, getLogger } from './log';
const python = process.platform === 'win32' ? 'python.exe' : 'python3';
export async function runPythonScript(script: string, logger?: Logger): Promise<string> {
const proc = spawn(python, [ '-c', script ]);
const procPromise = new Promise<void>((resolve, reject) => {
proc.on('error', (err: Error) => { reject(err); });
proc.on('exit', () => { resolve(); });
});
await procPromise;
const stdout = proc.stdout.read().toString();
const stderr = proc.stderr.read().toString();
if (stderr) {
if (logger === undefined) {
logger = getLogger();
}
logger.warning('python script has stderr.');
logger.warning('script:', script);
logger.warning('stderr:', stderr);
}
return stdout;
}
......@@ -25,8 +25,7 @@ import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
function getExperimentRootDir(): string {
return getExperimentStartupInfo()
.getLogDir();
return getExperimentStartupInfo().getLogDir();
}
function getLogDir(): string {
......@@ -34,8 +33,7 @@ function getLogDir(): string {
}
function getLogLevel(): string {
return getExperimentStartupInfo()
.getLogLevel();
return getExperimentStartupInfo().getLogLevel();
}
function getDefaultDatabaseDir(): string {
......@@ -481,6 +479,11 @@ async function getFreePort(host: string, start: number, end: number): Promise<nu
}
}
export function importModule(modulePath: string): any {
module.paths.unshift(path.dirname(modulePath));
return require(path.basename(modulePath));
}
export {
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, getFreePort, isPortOpen,
......
......@@ -445,10 +445,7 @@ class NNIManager implements Manager {
throw new Error('Cannot detect training service platform');
}
if (['remote', 'pai', 'aml', 'hybrid'].includes(platform)) {
const module_ = await import('../training_service/reusable/routerTrainingService');
return new module_.RouterTrainingService(config);
} else if (platform === 'local') {
if (platform === 'local') {
const module_ = await import('../training_service/local/localTrainingService');
return new module_.LocalTrainingService(config);
} else if (platform === 'kubeflow') {
......@@ -460,6 +457,9 @@ class NNIManager implements Manager {
} else if (platform === 'adl') {
const module_ = await import('../training_service/kubernetes/adl/adlTrainingService');
return new module_.AdlTrainingService();
} else {
const module_ = await import('../training_service/reusable/routerTrainingService');
return await module_.RouterTrainingService.construct(config);
}
throw new Error(`Unsupported training service platform "${platform}"`);
......
......@@ -83,11 +83,6 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const port: number = parseInt(strPort, 10);
const mode: string = parseArg(['--mode', '-m']);
if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'dlts', 'aml', 'adl', 'hybrid'].includes(mode)) {
console.log(`FATAL: unknown mode: ${mode}`);
usage();
process.exit(1);
}
const startMode: string = parseArg(['--start_mode', '-s']);
if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMode)) {
......
......@@ -6,9 +6,7 @@
import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { getExperimentRootDir } from '../../../common/utils';
import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient';
......@@ -31,10 +29,10 @@ export class AMLEnvironmentService extends EnvironmentService {
private experimentRootDir: string;
private config: FlattenAmlConfig;
constructor(config: ExperimentConfig) {
constructor(experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
super();
this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir();
this.experimentId = experimentId;
this.experimentRootDir = experimentRootDir;
this.config = flattenConfig(config, 'aml');
validateCodeDir(this.config.trialCodeDirectory);
}
......
......@@ -4,20 +4,31 @@ import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { EnvironmentService } from '../environment';
import { ExperimentConfig } from '../../../common/experimentConfig';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getCustomEnvironmentServiceConfig } from '../../../common/nniConfig';
import { getExperimentRootDir, importModule } from '../../../common/utils';
export async function createEnvironmentService(name: string, config: ExperimentConfig): Promise<EnvironmentService> {
const expId = getExperimentId();
const rootDir = getExperimentRootDir();
export class EnvironmentServiceFactory {
public static createEnvironmentService(name: string, config: ExperimentConfig): EnvironmentService {
switch(name) {
case 'local':
return new LocalEnvironmentService(config);
return new LocalEnvironmentService(rootDir, expId, config);
case 'remote':
return new RemoteEnvironmentService(config);
return new RemoteEnvironmentService(rootDir, expId, config);
case 'aml':
return new AMLEnvironmentService(config);
return new AMLEnvironmentService(rootDir, expId, config);
case 'openpai':
return new OpenPaiEnvironmentService(config);
default:
throw new Error(`${name} not supported!`);
return new OpenPaiEnvironmentService(rootDir, expId, config);
}
const esConfig = await getCustomEnvironmentServiceConfig(name);
if (esConfig === null) {
throw new Error(`${name} is not a supported training service!`);
}
const esModule = importModule(esConfig.nodeModulePath);
const esClass = esModule[esConfig.nodeClassName] as any;
return new esClass(rootDir, expId, config);
}
......@@ -7,11 +7,10 @@ import * as fs from 'fs';
import * as path from 'path';
import * as tkill from 'tree-kill';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { ExperimentConfig } from '../../../common/experimentConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { getExperimentRootDir, isAlive, getNewLine } from '../../../common/utils';
import { isAlive, getNewLine } from '../../../common/utils';
import { execMkdir, runScript, getScriptName, execCopydir } from '../../common/util';
import { SharedStorageService } from '../sharedStorage'
......@@ -22,10 +21,10 @@ export class LocalEnvironmentService extends EnvironmentService {
private experimentRootDir: string;
private experimentId: string;
constructor(_config: ExperimentConfig) {
constructor(experimentRootDir: string, experimentId: string, _config: ExperimentConfig) {
super();
this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir();
this.experimentId = experimentId;
this.experimentRootDir = experimentRootDir;
}
public get environmentMaintenceLoopInterval(): number {
......@@ -110,8 +109,6 @@ export class LocalEnvironmentService extends EnvironmentService {
const sharedStorageService = component.get<SharedStorageService>(SharedStorageService);
if (environment.useSharedStorage && sharedStorageService.canLocalMounted) {
this.experimentRootDir = sharedStorageService.localWorkingRoot;
} else {
this.experimentRootDir = getExperimentRootDir();
}
const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs");
if (environment.useSharedStorage && !sharedStorageService.canLocalMounted) {
......
......@@ -7,7 +7,6 @@ import * as yaml from 'js-yaml';
import * as request from 'request';
import { Deferred } from 'ts-deferred';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../../common/experimentConfig';
import { getLogger, Logger } from '../../../common/log';
import { PAIClusterConfig } from '../../pai/paiConfig';
......@@ -32,9 +31,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
private experimentId: string;
private config: FlattenOpenpaiConfig;
constructor(config: ExperimentConfig) {
constructor(_experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
super();
this.experimentId = getExperimentId();
this.experimentId = experimentId;
this.config = flattenConfig(config, 'openpai');
this.paiToken = this.config.token;
this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http';
......
......@@ -6,10 +6,9 @@
import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { getExperimentRootDir, getLogLevel } from '../../../common/utils';
import { getLogLevel } from '../../../common/utils';
import { ExperimentConfig, RemoteConfig, RemoteMachineConfig, flattenConfig } from '../../../common/experimentConfig';
import { execMkdir } from '../../common/util';
import { ExecutorManager } from '../../remote_machine/remoteMachineData';
......@@ -33,14 +32,13 @@ export class RemoteEnvironmentService extends EnvironmentService {
private experimentId: string;
private config: FlattenRemoteConfig;
constructor(config: ExperimentConfig) {
constructor(experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
super();
this.experimentId = getExperimentId();
this.experimentId = experimentId;
this.environmentExecutorManagerMap = new Map<string, ExecutorManager>();
this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>();
this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>();
this.experimentRootDir = getExperimentRootDir();
this.experimentId = getExperimentId();
this.experimentRootDir = experimentRootDir;
this.log = getLogger();
this.config = flattenConfig(config, 'remote');
......@@ -103,10 +101,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
// Create root working directory after executor is ready
const nniRootDir: string = executor.joinPath(executor.getTempPath(), 'nni-experiments');
await executor.createFolder(executor.getRemoteExperimentRootDir(getExperimentId()));
await executor.createFolder(executor.getRemoteExperimentRootDir(this.experimentId));
// the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = executor.getRemoteScriptsPath(getExperimentId());
const remoteGpuScriptCollectorDir: string = executor.getRemoteScriptsPath(this.experimentId);
// clean up previous result.
await executor.createFolder(remoteGpuScriptCollectorDir, true);
......@@ -245,7 +243,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
throw new Error(`Mount shared storage on remote machine failed.\n ERROR: ${result.stderr}`);
}
} else {
this.remoteExperimentRootDir = executor.getRemoteExperimentRootDir(getExperimentId());
this.remoteExperimentRootDir = executor.getRemoteExperimentRootDir(this.experimentId);
}
environment.command = await this.getScript(environment);
......
......@@ -3,7 +3,6 @@
'use strict';
import * as component from '../../common/component';
import { getLogger, Logger } from '../../common/log';
import { MethodNotImplementedError } from '../../common/errors';
import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig';
......@@ -18,23 +17,27 @@ import { TrialDispatcher } from './trialDispatcher';
* It's a intermedia implementation to support reusable training service.
* The final goal is to support reusable training job in higher level than training service.
*/
@component.Singleton
class RouterTrainingService implements TrainingService {
protected readonly log: Logger;
private internalTrainingService: TrainingService;
private log!: Logger;
private internalTrainingService!: TrainingService;
constructor(config: ExperimentConfig) {
this.log = getLogger();
public static async construct(config: ExperimentConfig): Promise<RouterTrainingService> {
const instance = new RouterTrainingService();
instance.log = getLogger('RouterTrainingService');
const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;
if (platform === 'remote' && !(<RemoteConfig>config.trainingService).reuseMode) {
this.internalTrainingService = new RemoteMachineTrainingService(config);
instance.internalTrainingService = new RemoteMachineTrainingService(config);
} else if (platform === 'openpai' && !(<OpenpaiConfig>config.trainingService).reuseMode) {
this.internalTrainingService = new PAITrainingService(config);
instance.internalTrainingService = new PAITrainingService(config);
} else {
this.internalTrainingService = new TrialDispatcher(config);
instance.internalTrainingService = await TrialDispatcher.construct(config);
}
return instance;
}
// eslint-disable-next-line @typescript-eslint/no-empty-function
private constructor() { }
public async listTrialJobs(): Promise<TrialJobDetail[]> {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
......
......@@ -203,7 +203,7 @@ describe('Unit Test for TrialDispatcher', () => {
});
beforeEach(async () => {
trialDispatcher = new TrialDispatcher(config);
trialDispatcher = await TrialDispatcher.construct(config);
// set ut environment
let environmentServiceList: EnvironmentService[] = [];
......
......@@ -24,7 +24,7 @@ import { TrialConfig } from '../common/trialConfig';
import { validateCodeDir } from '../common/util';
import { Command, CommandChannel } from './commandChannel';
import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment';
import { EnvironmentServiceFactory } from './environments/environmentServiceFactory';
import { createEnvironmentService } from './environments/environmentServiceFactory';
import { GpuScheduler } from './gpuScheduler';
import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService';
......@@ -39,20 +39,20 @@ import { TrialDetail } from './trial';
**/
@component.Singleton
class TrialDispatcher implements TrainingService {
private readonly log: Logger;
private readonly isDeveloping: boolean = false;
private log: Logger;
private isDeveloping: boolean = false;
private stopping: boolean = false;
private readonly metricsEmitter: EventEmitter;
private readonly experimentId: string;
private readonly experimentRootDir: string;
private metricsEmitter: EventEmitter;
private experimentId: string;
private experimentRootDir: string;
private enableVersionCheck: boolean = true;
private trialConfig: TrialConfig | undefined;
private readonly trials: Map<string, TrialDetail>;
private readonly environments: Map<string, EnvironmentInformation>;
private trials: Map<string, TrialDetail>;
private environments: Map<string, EnvironmentInformation>;
// make public for ut
public environmentServiceList: EnvironmentService[] = [];
public commandChannelSet: Set<CommandChannel>;
......@@ -82,8 +82,14 @@ class TrialDispatcher implements TrainingService {
private config: ExperimentConfig;
constructor(config: ExperimentConfig) {
this.log = getLogger();
public static async construct(config: ExperimentConfig): Promise<TrialDispatcher> {
const instance = new TrialDispatcher(config);
await instance.asyncConstructor(config);
return instance;
}
private constructor(config: ExperimentConfig) {
this.log = getLogger('TrialDispatcher');
this.trials = new Map<string, TrialDetail>();
this.environments = new Map<string, EnvironmentInformation>();
this.metricsEmitter = new EventEmitter();
......@@ -109,18 +115,14 @@ class TrialDispatcher implements TrainingService {
if (this.enableGpuScheduler) {
this.log.info(`TrialDispatcher: GPU scheduler is enabled.`)
}
}
validateCodeDir(config.trialCodeDirectory);
private async asyncConstructor(config: ExperimentConfig): Promise<void> {
await validateCodeDir(config.trialCodeDirectory);
if (Array.isArray(config.trainingService)) {
config.trainingService.forEach(trainingService => {
const env = EnvironmentServiceFactory.createEnvironmentService(trainingService.platform, config);
this.environmentServiceList.push(env);
});
} else {
const env = EnvironmentServiceFactory.createEnvironmentService(config.trainingService.platform, config);
this.environmentServiceList.push(env);
}
const serviceConfigs = Array.isArray(config.trainingService) ? config.trainingService : [ config.trainingService ];
const servicePromises = serviceConfigs.map(serviceConfig => createEnvironmentService(serviceConfig.platform, config));
this.environmentServiceList = await Promise.all(servicePromises);
this.environmentMaintenceLoopInterval = Math.max(
...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval)
......@@ -132,7 +134,7 @@ class TrialDispatcher implements TrainingService {
}
if (this.config.sharedStorage !== undefined) {
this.initializeSharedStorage(this.config.sharedStorage);
await this.initializeSharedStorage(this.config.sharedStorage);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment