Unverified Commit 872554f1 authored by SparkSnail, committed by GitHub

Support heterogeneous environment service (#3097)

parent dec91f7e
**Run an Experiment on Heterogeneous Mode**
===========================================
Running NNI in heterogeneous mode means that NNI will run trial jobs on multiple kinds of training platforms. For example, NNI could submit trial jobs to a remote machine and to AML simultaneously.
## Setup environment
NNI supports [local](./LocalMode.md), [remote](./RemoteMachineMode.md), [pai](./PaiMode.md), and [AML](./AMLMode.md) platforms for the heterogeneous training service. Before starting an experiment in this mode, users should set up the corresponding environment for each platform. More details about environment setup can be found in the corresponding docs.
## Run an experiment
Use `examples/trials/mnist-tfv1` as an example. The content of the NNI config YAML file is as follows:
.. code-block:: yaml

   authorName: default
   experimentName: example_mnist
   trialConcurrency: 2
   maxExecDuration: 1h
   maxTrialNum: 10
   trainingServicePlatform: heterogeneous
   searchSpacePath: search_space.json
   #choice: true, false
   useAnnotation: false
   tuner:
     builtinTunerName: TPE
     classArgs:
       #choice: maximize, minimize
       optimize_mode: maximize
   trial:
     command: python3 mnist.py
     codeDir: .
     gpuNum: 1
   heterogeneousConfig:
     trainingServicePlatforms:
       - local
       - remote
   remoteConfig:
     reuse: true
   machineList:
     - ip: 10.1.1.1
       username: bob
       passwd: bob123
Configurations for heterogeneous mode:

heterogeneousConfig:
* trainingServicePlatforms: required key. This field specifies the platforms used in heterogeneous mode; the value uses the YAML list format. NNI supports setting `local`, `remote`, `aml`, and `pai` in this field.

Note: if a platform is set in trainingServicePlatforms, users should also set the corresponding configuration for that platform, as the sketch after this note illustrates. For example, if `remote` is used as one of the platforms, the `machineList` and `remoteConfig` fields should also be set.
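For instance, a minimal sketch of adding `aml` as a third platform; the `amlConfig` values below are placeholders, not working credentials, and the required keys follow the aml config schema in this PR:

.. code-block:: yaml

   heterogeneousConfig:
     trainingServicePlatforms:
       - local
       - remote
       - aml
   amlConfig:
     subscriptionId: <your-subscription-id>
     resourceGroup: <your-resource-group>
     workspaceName: <your-workspace-name>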
......@@ -12,3 +12,4 @@ Introduction to NNI Training Services
FrameworkController<./TrainingService/FrameworkControllerMode>
DLTS<./TrainingService/DLTSMode>
AML<./TrainingService/AMLMode>
Heterogeneous<./TrainingService/HeterogeneousMode>
authorName: default
experimentName: example_mnist
trialConcurrency: 3
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: heterogeneous
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  gpuNum: 0
heterogeneousConfig:
  trainingServicePlatforms:
    - local
    - remote
remoteConfig:
  reuse: true
machineList:
  - ip: 10.1.1.1
    username: bob
    passwd: bob123
    #port can be skipped if using default ssh port 22
    #port: 22
\ No newline at end of file
......@@ -12,7 +12,8 @@ _trial_env_var_names = [
    'NNI_SYS_DIR',
    'NNI_OUTPUT_DIR',
    'NNI_TRIAL_SEQ_ID',
    'MULTI_PHASE'
    'MULTI_PHASE',
    'REUSE_MODE'
]
_dispatcher_env_var_names = [
......
......@@ -31,7 +31,7 @@ def init_logger() -> None:
    if trial_platform == 'unittest':
        return
    if trial_platform:
    if trial_platform and not trial_env_vars.REUSE_MODE:
        _init_logger_trial()
        return
......
......@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
    from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest':
    from .test import *
elif trial_env_vars.NNI_PLATFORM in ('adl', 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'):
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'):
    from .local import *
else:
    raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
......@@ -19,6 +19,7 @@ _outputdir = trial_env_vars.NNI_OUTPUT_DIR
if not os.path.exists(_outputdir):
    os.makedirs(_outputdir)
_reuse_mode = trial_env_vars.REUSE_MODE
_nni_platform = trial_env_vars.NNI_PLATFORM
_multiphase = trial_env_vars.MULTI_PHASE
......@@ -58,7 +59,7 @@ def get_next_parameter():
    return params

def send_metric(string):
    if _nni_platform != 'local':
    if _nni_platform != 'local' or _reuse_mode in ('true', 'True'):
        assert len(string) < 1000000, 'Metric too long'
        print("NNISDK_MEb'%s'" % (string), flush=True)
    else:
......
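With this change, a trial in reuse mode reports metrics through stdout using the `NNISDK_MEb` prefix printed above, which the trial runner collects and forwards. A standalone sketch of what such a line looks like (the metric payload here is an assumed example, not taken from this diff):

```python
import json

# Assumed example payload; only the NNISDK_MEb'...' wrapper is fixed by
# the send_metric code above.
metric = json.dumps({'parameter_id': 0, 'type': 'FINAL', 'value': '0.98'})
print("NNISDK_MEb'%s'" % metric, flush=True)
# stdout: NNISDK_MEb'{"parameter_id": 0, "type": "FINAL", "value": "0.98"}'
```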
......@@ -124,7 +124,7 @@ common_schema = {
    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
    Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
    'trainingServicePlatform': setChoice(
        'trainingServicePlatform', 'adl', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'),
        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'),
    Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
    Optional('multiPhase'): setType('multiPhase', bool),
    Optional('multiThread'): setType('multiThread', bool),
......@@ -208,7 +208,7 @@ pai_trial_schema = {
}

pai_config_schema = {
    'paiConfig': {
    Optional('paiConfig'): {
        'userName': setType('userName', str),
        Or('passWord', 'token', only_one=True): str,
        'host': setType('host', str),
......@@ -252,7 +252,7 @@ aml_trial_schema = {
}

aml_config_schema = {
    'amlConfig': {
    Optional('amlConfig'): {
        'subscriptionId': setType('subscriptionId', str),
        'resourceGroup': setType('resourceGroup', str),
        'workspaceName': setType('workspaceName', str),
......@@ -262,6 +262,29 @@ aml_config_schema = {
    }
}
heterogeneous_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        Optional('nniManagerNFSMountPath'): setPathCheck('nniManagerNFSMountPath'),
        Optional('containerNFSMountPath'): setType('containerNFSMountPath', str),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        'command': setType('command', str),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
        Optional('memoryMB'): setType('memoryMB', int),
        Optional('image'): setType('image', str),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('paiStorageConfigName'): setType('paiStorageConfigName', str),
        Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath')
    }
}

heterogeneous_config_schema = {
    'heterogeneousConfig': {
        'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml']
    }
}
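In the `schema` library, a list literal behaves like an Or over its items: the value must be a list, and every element must match one of the listed alternatives. A self-contained sketch of how the `trainingServicePlatforms` entry above validates:

```python
from schema import Schema, SchemaError

# Mirrors heterogeneous_config_schema: each entry must be one of the
# four supported platform names.
platforms = Schema(['local', 'remote', 'pai', 'aml'])

print(platforms.validate(['local', 'remote']))  # ['local', 'remote']
try:
    platforms.validate(['local', 'kubeflow'])   # 'kubeflow' is rejected
except SchemaError as error:
    print(error)
```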
adl_trial_schema = {
    'trial': {
        'codeDir': setType('codeDir', str),
......@@ -404,7 +427,7 @@ remote_config_schema = {
}

machine_list_schema = {
    'machineList': [Or(
    Optional('machineList'): [Or(
        {
            'ip': setType('ip', str),
            Optional('port'): setNumberRange('port', int, 1, 65535),
......@@ -438,6 +461,8 @@ training_service_schema_dict = {
    'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
    'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
    'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
    'heterogeneous': Schema({**common_schema, **heterogeneous_trial_schema, **heterogeneous_config_schema, **machine_list_schema,
                             **pai_config_schema, **aml_config_schema, **remote_config_schema}),
}
......@@ -454,6 +479,7 @@ class NNIConfigSchema:
        self.validate_pai_trial_conifg(experiment_config)
        self.validate_kubeflow_operators(experiment_config)
        self.validate_eth0_device(experiment_config)
        self.validate_heterogeneous_platforms(experiment_config)

    def validate_tuner_adivosr_assessor(self, experiment_config):
        if experiment_config.get('advisor'):
......@@ -563,3 +589,16 @@ class NNIConfigSchema:
            and not experiment_config.get('nniManagerIp') \
            and 'eth0' not in netifaces.interfaces():
            raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')
    def validate_heterogeneous_platforms(self, experiment_config):
        required_config_name_map = {
            'remote': 'machineList',
            'aml': 'amlConfig',
            'pai': 'paiConfig'
        }
        if experiment_config.get('trainingServicePlatform') == 'heterogeneous':
            for platform in experiment_config['heterogeneousConfig']['trainingServicePlatforms']:
                config_name = required_config_name_map.get(platform)
                if config_name and not experiment_config.get(config_name):
                    raise SchemaError('Need to set {0} for {1} in heterogeneous mode!'.format(config_name, platform))
\ No newline at end of file
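A hypothetical invocation of the check above (the config dict is abbreviated, and it assumes `NNIConfigSchema` is importable from this config schema module):

```python
# 'remote' is listed as a platform, but machineList is missing.
experiment_config = {
    'trainingServicePlatform': 'heterogeneous',
    'heterogeneousConfig': {'trainingServicePlatforms': ['local', 'remote']},
}
NNIConfigSchema().validate_heterogeneous_platforms(experiment_config)
# -> SchemaError: Need to set machineList for remote in heterogeneous mode!
```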
......@@ -118,13 +118,6 @@ def set_local_config(experiment_config, port, config_file_name):
    request_data = dict()
    if experiment_config.get('localConfig'):
        request_data['local_config'] = experiment_config['localConfig']
        if request_data['local_config']:
            if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
                request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
            if request_data['local_config'].get('maxTrialNumOnEachGpu'):
                request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
            if request_data['local_config'].get('useActiveGpu'):
                request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
    err_message = ''
    if not response or not check_response(response):
......@@ -306,6 +299,37 @@ def set_aml_config(experiment_config, port, config_file_name):
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message

def set_heterogeneous_config(experiment_config, port, config_file_name):
    '''set heterogeneous configuration'''
    heterogeneous_config_data = dict()
    heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig']
    platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
    for platform in platform_list:
        if platform == 'aml':
            heterogeneous_config_data['aml_config'] = experiment_config['amlConfig']
        elif platform == 'remote':
            if experiment_config.get('remoteConfig'):
                heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig']
            heterogeneous_config_data['machine_list'] = experiment_config['machineList']
        elif platform == 'local' and experiment_config.get('localConfig'):
            heterogeneous_config_data['local_config'] = experiment_config['localConfig']
        elif platform == 'pai':
            heterogeneous_config_data['pai_config'] = experiment_config['paiConfig']
    response = rest_put(cluster_metadata_url(port), json.dumps(heterogeneous_config_data), REST_TIME_OUT)
    err_message = None
    if not response or not response.status_code == 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message
def set_experiment(experiment_config, mode, port, config_file_name):
    '''Call startExperiment (rest POST /experiment) with yaml file content'''
    request_data = dict()
......@@ -387,6 +411,21 @@ def set_experiment(experiment_config, mode, port, config_file_name):
            {'key': 'aml_config', 'value': experiment_config['amlConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'heterogeneous':
        request_data['clusterMetaData'].append(
            {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']})
        platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
        request_dict = {
            'aml': {'key': 'aml_config', 'value': experiment_config.get('amlConfig')},
            'remote': {'key': 'machine_list', 'value': experiment_config.get('machineList')},
            'pai': {'key': 'pai_config', 'value': experiment_config.get('paiConfig')},
            'local': {'key': 'local_config', 'value': experiment_config.get('localConfig')}
        }
        for platform in platform_list:
            if request_dict.get(platform):
                request_data['clusterMetaData'].append(request_dict[platform])
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
    if check_response(response):
        return response
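As an illustration of the heterogeneous branch above, for the example YAML earlier in this PR (platforms `local` + `remote`) the payload would look roughly like the dict below; the keys mirror the code, and the values are abbreviated:

```python
request_data = {
    'clusterMetaData': [
        {'key': 'heterogeneous_config',
         'value': {'trainingServicePlatforms': ['local', 'remote']}},
        # Appended even though localConfig is unset: request_dict.get('local')
        # returns the wrapper dict, which is truthy regardless of its value.
        {'key': 'local_config', 'value': None},
        {'key': 'machine_list',
         'value': [{'ip': '10.1.1.1', 'username': 'bob', 'passwd': 'bob123'}]},
        {'key': 'trial_config',
         'value': {'command': 'python3 mnist.py', 'codeDir': '.', 'gpuNum': 0}},
    ]
}
```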
......@@ -420,6 +459,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
        config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name)
    elif platform == 'aml':
        config_result, err_msg = set_aml_config(experiment_config, port, config_file_name)
    elif platform == 'heterogeneous':
        config_result, err_msg = set_heterogeneous_config(experiment_config, port, config_file_name)
    else:
        raise Exception(ERROR_INFO % 'Unsupported platform!')
        exit(1)
......
......@@ -25,7 +25,6 @@ def main_loop(args):
    '''main loop logic for trial runner'''
    idle_last_time = datetime.now()
    gpu_refresh_last_time = datetime.now() - timedelta(minutes=1)
    try:
        if args.job_pid_file:
            with open(args.job_pid_file, 'w') as job_file:
......@@ -188,6 +187,7 @@ if __name__ == '__main__':
    os.environ['NNI_EXP_ID'] = args.exp_id
    os.environ['MULTI_PHASE'] = "true"
    os.environ['NNI_TRIAL_JOB_ID'] = "runner"
    os.environ['REUSE_MODE'] = "true"

    from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
    from .trial import Trial
......
......@@ -28,6 +28,7 @@ import { RouterTrainingService } from './training_service/reusable/routerTrainin
import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTrainingService';
import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService';
function initStartupInfo(
    startExpMode: string, experimentId: string, basePort: number, platform: string,
    logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string): void {
......@@ -36,22 +37,15 @@ function initStartupInfo(
}
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
    if (platformMode === 'adl') {
    const routerPlatformMode = ['remote', 'pai', 'aml', 'heterogeneous'];
    if (routerPlatformMode.includes(platformMode)) {
        Container.bind(TrainingService)
            .to(AdlTrainingService)
            .to(RouterTrainingService)
            .scope(Scope.Singleton);
    } else if (platformMode === 'local') {
        Container.bind(TrainingService)
            .to(LocalTrainingService)
            .scope(Scope.Singleton);
    } else if (platformMode === 'remote') {
        Container.bind(TrainingService)
            .to(RouterTrainingService)
            .scope(Scope.Singleton);
    } else if (platformMode === 'pai') {
        Container.bind(TrainingService)
            .to(RouterTrainingService)
            .scope(Scope.Singleton);
    } else if (platformMode === 'paiYarn') {
        Container.bind(TrainingService)
            .to(PAIYarnTrainingService)
......@@ -68,9 +62,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
        Container.bind(TrainingService)
            .to(DLTSTrainingService)
            .scope(Scope.Singleton);
    } else if (platformMode === 'aml') {
    } else if (platformMode === 'adl') {
        Container.bind(TrainingService)
            .to(RouterTrainingService)
            .to(AdlTrainingService)
            .scope(Scope.Singleton);
    } else {
        throw new Error(`Error: unsupported mode: ${platformMode}`);
......@@ -103,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
function usage(): void {
    console.info('usage: node main.js --port <port> --mode \
        <adl/local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
        <local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml/adl/heterogeneous> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
}
const strPort: string = parseArg(['--port', '-p']);
......@@ -123,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const port: number = parseInt(strPort, 10);
const mode: string = parseArg(['--mode', '-m']);
if (!['adl', 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) {
if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'].includes(mode)) {
    console.log(`FATAL: unknown mode: ${mode}`);
    usage();
    process.exit(1);
......
......@@ -23,7 +23,8 @@ export namespace ValidationSchemas {
        local_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
            gpuIndices: joi.string(),
            maxTrialNumPerGpu: joi.number(),
            useActiveGpu: joi.boolean()
            useActiveGpu: joi.boolean(),
            reuse: joi.boolean()
        }),
        trial_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
            image: joi.string().min(1),
......@@ -182,6 +183,9 @@
            maxTrialNumPerGpu: joi.number(),
            useActiveGpu: joi.boolean()
        }),
        heterogeneous_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
            trainingServicePlatforms: joi.array(),
        }),
        nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
            nniManagerIp: joi.string().min(1)
        }),
......
......@@ -11,6 +11,7 @@ export enum TrialConfigMetadataKey {
    LOCAL_CONFIG = 'local_config',
    TRIAL_CONFIG = 'trial_config',
    REMOTE_CONFIG = 'remote_config',
    HETEROGENEOUS_CONFIG = 'heterogeneous_config',
    EXPERIMENT_ID = 'experimentId',
    MULTI_PHASE = 'multiPhase',
    RANDOM_SCHEDULER = 'random_scheduler',
......@@ -22,5 +23,8 @@
    DLTS_CLUSTER_CONFIG = 'dlts_config',
    AML_CLUSTER_CONFIG = 'aml_config',
    VERSION_CHECK = 'version_check',
    LOG_COLLECTION = 'log_collection'
    LOG_COLLECTION = 'log_collection',
    // Used to set the platform list for heterogeneous mode in reuse mode;
    // a temporary change that will be refactored together with the config schema in the future.
    PLATFORM_LIST = 'platform_list'
}
......@@ -78,7 +78,7 @@ class LocalTrialJobDetail implements TrialJobDetail {
/**
 * Local training service config
 */
class LocalConfig {
export class LocalConfig {
    public maxTrialNumPerGpu?: number;
    public gpuIndices?: string;
    public useActiveGpu?: boolean;
......
......@@ -358,6 +358,10 @@ class RemoteMachineTrainingService implements TrainingService {
            case TrialConfigMetadataKey.LOG_COLLECTION:
                this.logCollection = value;
                break;
            case TrialConfigMetadataKey.REMOTE_CONFIG:
                // remote_config is handled by remoteEnvironmentService to set reuse mode;
                // it needs to be caught here, otherwise an unknown-key exception would be thrown below.
                break;
            default:
                //Reject for unknown keys
                throw new Error(`Unknown key: ${key}`);
......
......@@ -8,6 +8,7 @@ import { getBasePort, getExperimentId } from "../../../common/experimentStartupI
import { INITIALIZED } from '../../../core/commands';
import { CommandChannel, RunnerConnection } from "../commandChannel";
import { Channel, EnvironmentInformation } from "../environment";
import { EventEmitter } from "events";
class WebRunnerConnection extends RunnerConnection {
    public readonly clients: WebSocket[] = [];
......@@ -29,7 +30,7 @@ class WebRunnerConnection extends RunnerConnection {
export class WebCommandChannel extends CommandChannel {
    private readonly expId: string = getExperimentId();
    private static commandChannel: WebCommandChannel;
    private webSocketServer: SocketServer | undefined;
    private clients: Map<WebSocket, WebRunnerConnection | undefined> = new Map<WebSocket, WebRunnerConnection | undefined>();
......@@ -41,6 +42,18 @@ export class WebCommandChannel extends CommandChannel {
        // do nothing
    }

    // WebCommandChannel is a singleton: one experiment starts at most one WebCommandChannel instance.
    private constructor(commandEmitter: EventEmitter) {
        super(commandEmitter);
    }

    public static getInstance(commandEmitter: EventEmitter): CommandChannel {
        if (!this.commandChannel) {
            this.commandChannel = new WebCommandChannel(commandEmitter);
        }
        return this.commandChannel;
    }

    public async start(): Promise<void> {
        const port = getBasePort() + 1;
        this.webSocketServer = new SocketServer({ port });
......
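The singleton matters for heterogeneous mode: several environment services run within one experiment, and each calls `getInstance`, so they all share a single WebSocket command channel instead of competing for the port. A minimal sketch (import paths are assumed from the surrounding diff):

```typescript
import { EventEmitter } from "events";
import { WebCommandChannel } from "./channels/webCommandChannel";

const emitter = new EventEmitter();
// Both calls return the same cached instance, so a heterogeneous experiment
// with, say, local + remote environment services opens only one socket server.
const channelA = WebCommandChannel.getInstance(emitter);
const channelB = WebCommandChannel.getInstance(emitter);
console.assert(channelA === channelB);
```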
......@@ -3,12 +3,12 @@
'use strict';
import { EventEmitter } from "events";
import { getLogger, Logger } from "../../common/log";
import { TrialJobStatus } from "../../common/trainingService";
import { GPUInfo } from "../../training_service/common/gpuData";
import { WebCommandChannel } from "./channels/webCommandChannel";
import { CommandChannel } from "./commandChannel";
import { WebCommandChannel } from './channels/webCommandChannel';
import { EventEmitter } from "events";
export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED';
......@@ -75,6 +75,8 @@ export class EnvironmentInformation {
    public maxTrialNumberPerGpu?: number;
    public useActiveGpu?: boolean;
    public environmentService?: EnvironmentService;

    constructor(id: string, name: string, envId?: string) {
        this.log = getLogger();
        this.id = id;
......@@ -127,6 +129,8 @@ export abstract class EnvironmentService {
    public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>;
    public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>;
    public abstract startEnvironment(environment: EnvironmentInformation): Promise<void>;

    // Made accessible (protected) for unit tests.
    protected commandChannel: CommandChannel | undefined;

    // It is used to set the prefetched environment count; the default value is 0 for OpenPAI and AML mode,
    // while in remote mode it is set to the length of the machine list.
......@@ -134,6 +138,20 @@ export abstract class EnvironmentService {
        return 0;
    }

    public abstract get getName(): string;

    // Initialize the command channel; WebCommandChannel is used as the default.
    public initCommandChannel(eventEmitter: EventEmitter): void {
        this.commandChannel = WebCommandChannel.getInstance(eventEmitter);
    }

    public get getCommandChannel(): CommandChannel {
        if (this.commandChannel === undefined) {
            throw new Error("Command channel not initialized!");
        }
        return this.commandChannel;
    }

    // The interval depends on environment pressure and settings;
    // for example, OpenPAI relies on API calls with a frequency limit, so its interval needs to be bigger.
    public get environmentMaintenceLoopInterval(): number {
......@@ -147,10 +165,6 @@
        return true;
    }

    public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
        return new WebCommandChannel(commandEmitter);
    }

    public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
        return new EnvironmentInformation(envId, envName);
    }
......
......@@ -3,7 +3,6 @@
'use strict';
import { EventEmitter } from "events";
import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../common/component';
......@@ -14,13 +13,13 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient';
import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig';
import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { CommandChannel } from "../commandChannel";
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel';
/**
 * Collector PAI jobs info from PAI cluster, and update pai job status locally
 * Collect AML job info from the AML cluster, and update AML job status locally
 */
@component.Singleton
export class AMLEnvironmentService extends EnvironmentService {
......@@ -41,14 +40,18 @@ export class AMLEnvironmentService extends EnvironmentService {
        return false;
    }

    public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
        return new AMLCommandChannel(commandEmitter);
    public initCommandChannel(eventEmitter: EventEmitter): void {
        this.commandChannel = new AMLCommandChannel(eventEmitter);
    }

    public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
        return new AMLEnvironmentInformation(envId, envName);
    }

    public get getName(): string {
        return 'aml';
    }

    public async config(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
......
import { AMLEnvironmentService } from './amlEnvironmentService';
import { OpenPaiEnvironmentService } from './openPaiEnvironmentService';
import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { EnvironmentService } from '../environment';

export class EnvironmentServiceFactory {
    public static createEnvironmentService(name: string): EnvironmentService {
        switch (name) {
            case 'local':
                return new LocalEnvironmentService();
            case 'remote':
                return new RemoteEnvironmentService();
            case 'aml':
                return new AMLEnvironmentService();
            case 'pai':
                return new OpenPaiEnvironmentService();
            default:
                throw new Error(`${name} not supported!`);
        }
    }
}
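Usage sketch: in heterogeneous mode the router side can build one environment service per entry in `trainingServicePlatforms`. The platform list and the import path here are illustrative assumptions:

```typescript
import { EnvironmentServiceFactory } from './environmentServiceFactory';

// Illustrative platform list, e.g. parsed from heterogeneousConfig.
const platforms = ['local', 'remote'];
const services = platforms.map(name =>
    EnvironmentServiceFactory.createEnvironmentService(name));
// services[0] is a LocalEnvironmentService, services[1] a RemoteEnvironmentService.
```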
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import * as fs from 'fs';
import * as path from 'path';
import * as tkill from 'tree-kill';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { TrialConfig } from '../../common/trialConfig';
import { getExperimentRootDir, isAlive } from '../../../common/utils';
import { execMkdir, runScript, execCopydir } from '../../common/util';

@component.Singleton
export class LocalEnvironmentService extends EnvironmentService {

    private readonly log: Logger = getLogger();
    private localTrialConfig: TrialConfig | undefined;
    private experimentRootDir: string;
    private experimentId: string;

    constructor() {
        super();
        this.experimentId = getExperimentId();
        this.experimentRootDir = getExperimentRootDir();
    }

    public get environmentMaintenceLoopInterval(): number {
        return 100;
    }

    public get hasStorageService(): boolean {
        return false;
    }

    public get getName(): string {
        return 'local';
    }

    public async config(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.TRIAL_CONFIG:
                this.localTrialConfig = <TrialConfig>JSON.parse(value);
                break;
            default:
                this.log.debug(`Local mode does not process metadata key: '${key}', value: '${value}'`);
        }
    }

    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        environments.forEach(async (environment) => {
            const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
            const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`;
            /* eslint-disable require-atomic-updates */
            try {
                // check if pid file exist
                const pidExist = await fs.existsSync(jobpidPath);
                if (!pidExist) {
                    return;
                }
                const pid: string = await fs.promises.readFile(jobpidPath, 'utf8');
                const alive: boolean = await isAlive(pid);
                environment.status = 'RUNNING';
                // if the process of jobpid is not alive any more
                if (!alive) {
                    if (fs.existsSync(runnerReturnCodeFilePath)) {
                        const runnerReturnCode: string = await fs.promises.readFile(runnerReturnCodeFilePath, 'utf8');
                        const match: RegExpMatchArray | null = runnerReturnCode.trim()
                            .match(/^-?(\d+)\s+(\d+)$/);
                        if (match !== null) {
                            const { 1: code } = match;
                            // Update trial job's status based on result code
                            if (parseInt(code, 10) === 0) {
                                environment.setStatus('SUCCEEDED');
                            } else {
                                environment.setStatus('FAILED');
                            }
                        }
                    }
                }
            } catch (error) {
                this.log.error(`Update job status exception, error is ${error.message}`);
            }
        });
    }

    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        if (this.localTrialConfig === undefined) {
            throw new Error('Local trial config is not initialized');
        }
        // Needs refactoring: this temp folder path is not appropriate, there are two expIds in this path
        const localTempFolder: string = path.join(this.experimentRootDir, this.experimentId,
            "environment-temp", "envs");
        const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs");
        environment.runnerWorkingFolder = path.join(localEnvCodeFolder, environment.id);
        await execMkdir(environment.runnerWorkingFolder);
        await execCopydir(localTempFolder, localEnvCodeFolder);
        environment.command = `cd ${this.experimentRootDir} && \
${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \
1>${environment.runnerWorkingFolder}/trialrunner_stdout 2>${environment.runnerWorkingFolder}/trialrunner_stderr \
&& echo $? \`date +%s%3N\` >${environment.runnerWorkingFolder}/code`;
        await fs.promises.writeFile(path.join(localEnvCodeFolder, 'nni_run.sh'),
            environment.command, { encoding: 'utf8', mode: 0o777 });
        // Execute command in local machine
        runScript(path.join(localEnvCodeFolder, 'nni_run.sh'));
        environment.trackingUrl = `${environment.runnerWorkingFolder}`;
    }

    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
        const pid: string = await fs.promises.readFile(jobpidPath, 'utf8');
        tkill(Number(pid), 'SIGKILL');
    }
}