Commit df6145a2 authored by Yuge Zhang's avatar Yuge Zhang
Browse files

Merge branch 'master' of https://github.com/microsoft/nni into dev-retiarii

parents 0f0c6288 f8424a9f
...@@ -54,6 +54,7 @@ def update_experiment(): ...@@ -54,6 +54,7 @@ def update_experiment():
rest_pid = nni_config.get_config('restServerPid') rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED') experiment_config.update_experiment(key, 'status', 'STOPPED')
experiment_config.update_experiment(key, 'port', None)
continue continue
def check_experiment_id(args, update=True): def check_experiment_id(args, update=True):
......
...@@ -43,8 +43,8 @@ def get_registered_algo_meta(builtin_name, algo_type=None): ...@@ -43,8 +43,8 @@ def get_registered_algo_meta(builtin_name, algo_type=None):
------- -------
Returns meta information of speicified builtin alogorithms, for example: Returns meta information of speicified builtin alogorithms, for example:
{ {
'classArgsValidator': 'nni.smac_tuner.smac_tuner.SMACClassArgsValidator', 'classArgsValidator': 'nni.smac_tuner.SMACClassArgsValidator',
'className': 'nni.smac_tuner.smac_tuner.SMACTuner', 'className': 'nni.smac_tuner.SMACTuner',
'builtinName': 'SMAC' 'builtinName': 'SMAC'
} }
""" """
......
...@@ -25,7 +25,6 @@ def main_loop(args): ...@@ -25,7 +25,6 @@ def main_loop(args):
'''main loop logic for trial runner''' '''main loop logic for trial runner'''
idle_last_time = datetime.now() idle_last_time = datetime.now()
gpu_refresh_last_time = datetime.now() - timedelta(minutes=1) gpu_refresh_last_time = datetime.now() - timedelta(minutes=1)
try: try:
if args.job_pid_file: if args.job_pid_file:
with open(args.job_pid_file, 'w') as job_file: with open(args.job_pid_file, 'w') as job_file:
...@@ -188,6 +187,7 @@ if __name__ == '__main__': ...@@ -188,6 +187,7 @@ if __name__ == '__main__':
os.environ['NNI_EXP_ID'] = args.exp_id os.environ['NNI_EXP_ID'] = args.exp_id
os.environ['MULTI_PHASE'] = "true" os.environ['MULTI_PHASE'] = "true"
os.environ['NNI_TRIAL_JOB_ID'] = "runner" os.environ['NNI_TRIAL_JOB_ID'] = "runner"
os.environ['REUSE_MODE'] = "true"
from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
from .trial import Trial from .trial import Trial
......
...@@ -11,22 +11,21 @@ import sys ...@@ -11,22 +11,21 @@ import sys
from collections import deque from collections import deque
from unittest import TestCase, main from unittest import TestCase, main
from nni.algorithms.hpo.batch_tuner.batch_tuner import BatchTuner from nni.algorithms.hpo.batch_tuner import BatchTuner
from nni.algorithms.hpo.evolution_tuner.evolution_tuner import EvolutionTuner from nni.algorithms.hpo.evolution_tuner import EvolutionTuner
from nni.algorithms.hpo.gp_tuner.gp_tuner import GPTuner from nni.algorithms.hpo.gp_tuner import GPTuner
from nni.algorithms.hpo.gridsearch_tuner.gridsearch_tuner import GridSearchTuner from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from nni.algorithms.hpo.metis_tuner.metis_tuner import MetisTuner from nni.algorithms.hpo.metis_tuner import MetisTuner
from nni.algorithms.hpo.pbt_tuner.pbt_tuner import PBTTuner from nni.algorithms.hpo.pbt_tuner import PBTTuner
from nni.algorithms.hpo.regularized_evolution_tuner.regularized_evolution_tuner import RegularizedEvolutionTuner from nni.algorithms.hpo.regularized_evolution_tuner import RegularizedEvolutionTuner
from nni.runtime.msg_dispatcher import _pack_parameter, MsgDispatcher from nni.runtime.msg_dispatcher import _pack_parameter, MsgDispatcher
if sys.platform != 'win32': if sys.platform != 'win32':
from nni.algorithms.hpo.smac_tuner.smac_tuner import SMACTuner from nni.algorithms.hpo.smac_tuner import SMACTuner
from nni.tuner import Tuner from nni.tuner import Tuner
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('test_tuner') logger = logging.getLogger('test_tuner')
......
...@@ -9,7 +9,7 @@ from unittest import TestCase, main ...@@ -9,7 +9,7 @@ from unittest import TestCase, main
import hyperopt as hp import hyperopt as hp
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import json2space, json2parameter, json2vals, HyperoptTuner from nni.algorithms.hpo.hyperopt_tuner import json2space, json2parameter, json2vals, HyperoptTuner
class HyperoptTunerTestCase(TestCase): class HyperoptTunerTestCase(TestCase):
......
...@@ -6,6 +6,7 @@ from unittest import TestCase, main ...@@ -6,6 +6,7 @@ from unittest import TestCase, main
from copy import deepcopy from copy import deepcopy
import torch import torch
from nni.algorithms.hpo.networkmorphism_tuner import NetworkMorphismTuner
from nni.algorithms.hpo.networkmorphism_tuner.graph import graph_to_json, json_to_graph from nni.algorithms.hpo.networkmorphism_tuner.graph import graph_to_json, json_to_graph
from nni.algorithms.hpo.networkmorphism_tuner.graph_transformer import ( from nni.algorithms.hpo.networkmorphism_tuner.graph_transformer import (
to_deeper_graph, to_deeper_graph,
...@@ -13,7 +14,6 @@ from nni.algorithms.hpo.networkmorphism_tuner.graph_transformer import ( ...@@ -13,7 +14,6 @@ from nni.algorithms.hpo.networkmorphism_tuner.graph_transformer import (
to_wider_graph, to_wider_graph,
) )
from nni.algorithms.hpo.networkmorphism_tuner.layers import layer_description_extractor from nni.algorithms.hpo.networkmorphism_tuner.layers import layer_description_extractor
from nni.algorithms.hpo.networkmorphism_tuner.networkmorphism_tuner import NetworkMorphismTuner
from nni.algorithms.hpo.networkmorphism_tuner.nn import CnnGenerator from nni.algorithms.hpo.networkmorphism_tuner.nn import CnnGenerator
......
...@@ -77,7 +77,11 @@ class NNIExperimentsManager implements ExperimentManager { ...@@ -77,7 +77,11 @@ class NNIExperimentsManager implements ExperimentManager {
this.withLockSync(() => { this.withLockSync(() => {
const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString()); const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString());
assert(experimentId in experimentsInformation, `Experiment Manager: Experiment Id ${experimentId} not found, this should not happen`); assert(experimentId in experimentsInformation, `Experiment Manager: Experiment Id ${experimentId} not found, this should not happen`);
experimentsInformation[experimentId][key] = value; if (value !== undefined) {
experimentsInformation[experimentId][key] = value;
} else {
delete experimentsInformation[experimentId][key];
}
fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4)); fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4));
}); });
} catch (err) { } catch (err) {
...@@ -128,6 +132,7 @@ class NNIExperimentsManager implements ExperimentManager { ...@@ -128,6 +132,7 @@ class NNIExperimentsManager implements ExperimentManager {
updateList.forEach((expId: string) => { updateList.forEach((expId: string) => {
if (experimentsInformation[expId]) { if (experimentsInformation[expId]) {
experimentsInformation[expId]['status'] = 'STOPPED'; experimentsInformation[expId]['status'] = 'STOPPED';
delete experimentsInformation[expId]['port'];
} else { } else {
this.log.error(`Experiment Manager: Experiment Id ${expId} not found, this should not happen`); this.log.error(`Experiment Manager: Experiment Id ${expId} not found, this should not happen`);
} }
......
...@@ -480,6 +480,7 @@ class NNIManager implements Manager { ...@@ -480,6 +480,7 @@ class NNIManager implements Manager {
} }
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.setStatus('STOPPED'); this.setStatus('STOPPED');
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'port', undefined);
} }
private async periodicallyUpdateExecDuration(): Promise<void> { private async periodicallyUpdateExecDuration(): Promise<void> {
......
...@@ -28,6 +28,7 @@ import { RouterTrainingService } from './training_service/reusable/routerTrainin ...@@ -28,6 +28,7 @@ import { RouterTrainingService } from './training_service/reusable/routerTrainin
import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTrainingService'; import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTrainingService';
import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService'; import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService';
function initStartupInfo( function initStartupInfo(
startExpMode: string, experimentId: string, basePort: number, platform: string, startExpMode: string, experimentId: string, basePort: number, platform: string,
logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string): void { logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string): void {
...@@ -36,22 +37,15 @@ function initStartupInfo( ...@@ -36,22 +37,15 @@ function initStartupInfo(
} }
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> { async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
if (platformMode === 'adl') { const routerPlatformMode = ['remote', 'pai', 'aml', 'heterogeneous'];
if (routerPlatformMode.includes(platformMode)) {
Container.bind(TrainingService) Container.bind(TrainingService)
.to(AdlTrainingService) .to(RouterTrainingService)
.scope(Scope.Singleton); .scope(Scope.Singleton);
} else if (platformMode === 'local') { } else if (platformMode === 'local') {
Container.bind(TrainingService) Container.bind(TrainingService)
.to(LocalTrainingService) .to(LocalTrainingService)
.scope(Scope.Singleton); .scope(Scope.Singleton);
} else if (platformMode === 'remote') {
Container.bind(TrainingService)
.to(RouterTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'pai') {
Container.bind(TrainingService)
.to(RouterTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'paiYarn') { } else if (platformMode === 'paiYarn') {
Container.bind(TrainingService) Container.bind(TrainingService)
.to(PAIYarnTrainingService) .to(PAIYarnTrainingService)
...@@ -68,9 +62,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN ...@@ -68,9 +62,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
Container.bind(TrainingService) Container.bind(TrainingService)
.to(DLTSTrainingService) .to(DLTSTrainingService)
.scope(Scope.Singleton); .scope(Scope.Singleton);
} else if (platformMode === 'aml') { } else if (platformMode === 'adl') {
Container.bind(TrainingService) Container.bind(TrainingService)
.to(RouterTrainingService) .to(AdlTrainingService)
.scope(Scope.Singleton); .scope(Scope.Singleton);
} else { } else {
throw new Error(`Error: unsupported mode: ${platformMode}`); throw new Error(`Error: unsupported mode: ${platformMode}`);
...@@ -103,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN ...@@ -103,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
function usage(): void { function usage(): void {
console.info('usage: node main.js --port <port> --mode \ console.info('usage: node main.js --port <port> --mode \
<adl/local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>'); <local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml/adl/heterogeneous> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
} }
const strPort: string = parseArg(['--port', '-p']); const strPort: string = parseArg(['--port', '-p']);
...@@ -123,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals ...@@ -123,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const port: number = parseInt(strPort, 10); const port: number = parseInt(strPort, 10);
const mode: string = parseArg(['--mode', '-m']); const mode: string = parseArg(['--mode', '-m']);
if (!['adl', 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) { if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'].includes(mode)) {
console.log(`FATAL: unknown mode: ${mode}`); console.log(`FATAL: unknown mode: ${mode}`);
usage(); usage();
process.exit(1); process.exit(1);
......
...@@ -23,7 +23,8 @@ export namespace ValidationSchemas { ...@@ -23,7 +23,8 @@ export namespace ValidationSchemas {
local_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase local_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
gpuIndices: joi.string(), gpuIndices: joi.string(),
maxTrialNumPerGpu: joi.number(), maxTrialNumPerGpu: joi.number(),
useActiveGpu: joi.boolean() useActiveGpu: joi.boolean(),
reuse: joi.boolean()
}), }),
trial_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase trial_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
image: joi.string().min(1), image: joi.string().min(1),
...@@ -182,6 +183,9 @@ export namespace ValidationSchemas { ...@@ -182,6 +183,9 @@ export namespace ValidationSchemas {
maxTrialNumPerGpu: joi.number(), maxTrialNumPerGpu: joi.number(),
useActiveGpu: joi.boolean() useActiveGpu: joi.boolean()
}), }),
heterogeneous_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
trainingServicePlatforms: joi.array(),
}),
nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
nniManagerIp: joi.string().min(1) nniManagerIp: joi.string().min(1)
}), }),
......
...@@ -11,6 +11,7 @@ export enum TrialConfigMetadataKey { ...@@ -11,6 +11,7 @@ export enum TrialConfigMetadataKey {
LOCAL_CONFIG = 'local_config', LOCAL_CONFIG = 'local_config',
TRIAL_CONFIG = 'trial_config', TRIAL_CONFIG = 'trial_config',
REMOTE_CONFIG = 'remote_config', REMOTE_CONFIG = 'remote_config',
HETEROGENEOUS_CONFIG = 'heterogeneous_config',
EXPERIMENT_ID = 'experimentId', EXPERIMENT_ID = 'experimentId',
MULTI_PHASE = 'multiPhase', MULTI_PHASE = 'multiPhase',
RANDOM_SCHEDULER = 'random_scheduler', RANDOM_SCHEDULER = 'random_scheduler',
...@@ -22,5 +23,8 @@ export enum TrialConfigMetadataKey { ...@@ -22,5 +23,8 @@ export enum TrialConfigMetadataKey {
DLTS_CLUSTER_CONFIG = 'dlts_config', DLTS_CLUSTER_CONFIG = 'dlts_config',
AML_CLUSTER_CONFIG = 'aml_config', AML_CLUSTER_CONFIG = 'aml_config',
VERSION_CHECK = 'version_check', VERSION_CHECK = 'version_check',
LOG_COLLECTION = 'log_collection' LOG_COLLECTION = 'log_collection',
// Used to set platform for heterogeneous in reuse mode,
// temproarily change and will refactor config schema in the future
PLATFORM_LIST = 'platform_list'
} }
...@@ -78,7 +78,7 @@ class LocalTrialJobDetail implements TrialJobDetail { ...@@ -78,7 +78,7 @@ class LocalTrialJobDetail implements TrialJobDetail {
/** /**
* Local training service config * Local training service config
*/ */
class LocalConfig { export class LocalConfig {
public maxTrialNumPerGpu?: number; public maxTrialNumPerGpu?: number;
public gpuIndices?: string; public gpuIndices?: string;
public useActiveGpu?: boolean; public useActiveGpu?: boolean;
...@@ -253,7 +253,20 @@ class LocalTrainingService implements TrainingService { ...@@ -253,7 +253,20 @@ class LocalTrainingService implements TrainingService {
return Promise.resolve(); return Promise.resolve();
} }
tkill(trialJob.pid, 'SIGKILL'); tkill(trialJob.pid, 'SIGTERM');
const startTime = Date.now();
while(await isAlive(trialJob.pid)) {
if (Date.now() - startTime > 4999) {
tkill(trialJob.pid, 'SIGKILL', (err) => {
if (err) {
this.log.error(`kill trial job error: ${err}`);
}
});
break;
}
await delay(500);
}
this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped)); this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped));
return Promise.resolve(); return Promise.resolve();
......
...@@ -358,6 +358,10 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -358,6 +358,10 @@ class RemoteMachineTrainingService implements TrainingService {
case TrialConfigMetadataKey.LOG_COLLECTION: case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value; this.logCollection = value;
break; break;
case TrialConfigMetadataKey.REMOTE_CONFIG:
// Add remote_config in remoteEnvironmentService to set reuse mode,
// this config need to be catched here, otherwise will throw Unknown key exception here
break;
default: default:
//Reject for unknown keys //Reject for unknown keys
throw new Error(`Uknown key: ${key}`); throw new Error(`Uknown key: ${key}`);
......
...@@ -8,6 +8,7 @@ import { getBasePort, getExperimentId } from "../../../common/experimentStartupI ...@@ -8,6 +8,7 @@ import { getBasePort, getExperimentId } from "../../../common/experimentStartupI
import { INITIALIZED } from '../../../core/commands'; import { INITIALIZED } from '../../../core/commands';
import { CommandChannel, RunnerConnection } from "../commandChannel"; import { CommandChannel, RunnerConnection } from "../commandChannel";
import { Channel, EnvironmentInformation } from "../environment"; import { Channel, EnvironmentInformation } from "../environment";
import { EventEmitter } from "events";
class WebRunnerConnection extends RunnerConnection { class WebRunnerConnection extends RunnerConnection {
public readonly clients: WebSocket[] = []; public readonly clients: WebSocket[] = [];
...@@ -29,7 +30,7 @@ class WebRunnerConnection extends RunnerConnection { ...@@ -29,7 +30,7 @@ class WebRunnerConnection extends RunnerConnection {
export class WebCommandChannel extends CommandChannel { export class WebCommandChannel extends CommandChannel {
private readonly expId: string = getExperimentId(); private readonly expId: string = getExperimentId();
private static commandChannel: WebCommandChannel;
private webSocketServer: SocketServer | undefined; private webSocketServer: SocketServer | undefined;
private clients: Map<WebSocket, WebRunnerConnection | undefined> = new Map<WebSocket, WebRunnerConnection | undefined>(); private clients: Map<WebSocket, WebRunnerConnection | undefined> = new Map<WebSocket, WebRunnerConnection | undefined>();
...@@ -40,6 +41,18 @@ export class WebCommandChannel extends CommandChannel { ...@@ -40,6 +41,18 @@ export class WebCommandChannel extends CommandChannel {
public async config(_key: string, _value: any): Promise<void> { public async config(_key: string, _value: any): Promise<void> {
// do nothing // do nothing
} }
// Set WebCommandChannel as singleton mode, one experiment could only start one webCommandChannel instance
private constructor(commandEmitter: EventEmitter) {
super(commandEmitter);
}
public static getInstance(commandEmitter: EventEmitter): CommandChannel {
if (!this.commandChannel) {
this.commandChannel = new WebCommandChannel(commandEmitter);
}
return this.commandChannel;
}
public async start(): Promise<void> { public async start(): Promise<void> {
const port = getBasePort() + 1; const port = getBasePort() + 1;
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
'use strict'; 'use strict';
import { EventEmitter } from "events";
import { getLogger, Logger } from "../../common/log"; import { getLogger, Logger } from "../../common/log";
import { TrialJobStatus } from "../../common/trainingService"; import { TrialJobStatus } from "../../common/trainingService";
import { GPUInfo } from "../../training_service/common/gpuData"; import { GPUInfo } from "../../training_service/common/gpuData";
import { WebCommandChannel } from "./channels/webCommandChannel";
import { CommandChannel } from "./commandChannel"; import { CommandChannel } from "./commandChannel";
import { WebCommandChannel } from './channels/webCommandChannel';
import { EventEmitter } from "events";
export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED'; export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED';
...@@ -75,6 +75,8 @@ export class EnvironmentInformation { ...@@ -75,6 +75,8 @@ export class EnvironmentInformation {
public maxTrialNumberPerGpu?: number; public maxTrialNumberPerGpu?: number;
public useActiveGpu?: boolean; public useActiveGpu?: boolean;
public environmentService?: EnvironmentService;
constructor(id: string, name: string, envId?: string) { constructor(id: string, name: string, envId?: string) {
this.log = getLogger(); this.log = getLogger();
this.id = id; this.id = id;
...@@ -127,6 +129,8 @@ export abstract class EnvironmentService { ...@@ -127,6 +129,8 @@ export abstract class EnvironmentService {
public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>; public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>;
public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>; public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>;
public abstract startEnvironment(environment: EnvironmentInformation): Promise<void>; public abstract startEnvironment(environment: EnvironmentInformation): Promise<void>;
// Make public for ut
protected commandChannel: CommandChannel | undefined;
// It is used to set prefetched environment count, default value is 0 for OpenPAI and AML mode, // It is used to set prefetched environment count, default value is 0 for OpenPAI and AML mode,
// in remote mode, this value is set to the length of machine list. // in remote mode, this value is set to the length of machine list.
...@@ -134,6 +138,20 @@ export abstract class EnvironmentService { ...@@ -134,6 +138,20 @@ export abstract class EnvironmentService {
return 0; return 0;
} }
public abstract get getName(): string;
// Initialize command channel, use WebCommandChannel as default command channel
public initCommandChannel(eventEmitter: EventEmitter): void {
this.commandChannel = WebCommandChannel.getInstance(eventEmitter);
}
public get getCommandChannel(): CommandChannel {
if (this.commandChannel === undefined) {
throw new Error("Command channel not initialized!");
}
return this.commandChannel;
}
// It depends on environment pressure and settings // It depends on environment pressure and settings
// for example, OpenPAI relies on API calls, and there is an limitation for frequence, so it need to be bigger. // for example, OpenPAI relies on API calls, and there is an limitation for frequence, so it need to be bigger.
public get environmentMaintenceLoopInterval(): number { public get environmentMaintenceLoopInterval(): number {
...@@ -147,10 +165,6 @@ export abstract class EnvironmentService { ...@@ -147,10 +165,6 @@ export abstract class EnvironmentService {
return true; return true;
} }
public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
return new WebCommandChannel(commandEmitter);
}
public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation { public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
return new EnvironmentInformation(envId, envName); return new EnvironmentInformation(envId, envName);
} }
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
'use strict'; 'use strict';
import { EventEmitter } from "events";
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
...@@ -14,13 +13,13 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; ...@@ -14,13 +13,13 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { validateCodeDir } from '../../common/util'; import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient'; import { AMLClient } from '../aml/amlClient';
import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig'; import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig';
import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { CommandChannel } from "../commandChannel";
import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment';
import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel';
/** /**
* Collector PAI jobs info from PAI cluster, and update pai job status locally * Collector AML jobs info from AML cluster, and update aml job status locally
*/ */
@component.Singleton @component.Singleton
export class AMLEnvironmentService extends EnvironmentService { export class AMLEnvironmentService extends EnvironmentService {
...@@ -41,14 +40,18 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -41,14 +40,18 @@ export class AMLEnvironmentService extends EnvironmentService {
return false; return false;
} }
public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { public initCommandChannel(eventEmitter: EventEmitter): void {
return new AMLCommandChannel(commandEmitter); this.commandChannel = new AMLCommandChannel(eventEmitter);
} }
public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation { public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
return new AMLEnvironmentInformation(envId, envName); return new AMLEnvironmentInformation(envId, envName);
} }
public get getName(): string {
return 'aml';
}
public async config(key: string, value: string): Promise<void> { public async config(key: string, value: string): Promise<void> {
switch (key) { switch (key) {
case TrialConfigMetadataKey.AML_CLUSTER_CONFIG: case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
......
import { AMLEnvironmentService } from './amlEnvironmentService';
import { OpenPaiEnvironmentService } from './openPaiEnvironmentService';
import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { EnvironmentService } from '../environment';
/**
 * Factory that maps a training-platform name to a freshly constructed
 * environment service for that platform.
 */
export class EnvironmentServiceFactory {
    /**
     * Build the environment service registered under `name`.
     *
     * @param name platform identifier; one of 'local', 'remote', 'aml' or 'pai'
     * @returns a new EnvironmentService instance for the requested platform
     * @throws Error when `name` is not one of the supported identifiers
     */
    public static createEnvironmentService(name: string): EnvironmentService {
        if (name === 'local') {
            return new LocalEnvironmentService();
        }
        if (name === 'remote') {
            return new RemoteEnvironmentService();
        }
        if (name === 'aml') {
            return new AMLEnvironmentService();
        }
        if (name === 'pai') {
            return new OpenPaiEnvironmentService();
        }
        // Anything else is an unknown platform.
        throw new Error(`${name} not supported!`);
    }
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import * as tkill from 'tree-kill';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { TrialConfig } from '../../common/trialConfig';
import { getExperimentRootDir, isAlive } from '../../../common/utils';
import { execMkdir, runScript, execCopydir } from '../../common/util';
/**
 * Environment service that launches the environment command as a script on the
 * local machine and tracks it through files written into the environment's
 * runner working folder (`pid`, `code`, `trialrunner_stdout`, `trialrunner_stderr`).
 */
@component.Singleton
export class LocalEnvironmentService extends EnvironmentService {
    private readonly log: Logger = getLogger();
    // Parsed TRIAL_CONFIG metadata; must be set through config() before startEnvironment().
    private localTrialConfig: TrialConfig | undefined;
    private experimentRootDir: string;
    private experimentId: string;

    constructor() {
        super();
        this.experimentId = getExperimentId();
        this.experimentRootDir = getExperimentRootDir();
    }

    public get environmentMaintenceLoopInterval(): number {
        return 100;
    }

    public get hasStorageService(): boolean {
        return false;
    }

    public get getName(): string {
        return 'local';
    }

    /**
     * Accept a metadata key/value pair. Only TRIAL_CONFIG is consumed (parsed
     * from JSON); every other key is logged at debug level and ignored.
     *
     * @param key metadata key
     * @param value metadata value; JSON text when key is TRIAL_CONFIG
     */
    public async config(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.TRIAL_CONFIG:
                this.localTrialConfig = <TrialConfig>JSON.parse(value);
                break;
            default:
                this.log.debug(`Local mode does not proccess metadata key: '${key}', value: '${value}'`);
        }
    }

    /**
     * Refresh the status of every given environment.
     *
     * The returned promise resolves only after all environments have been
     * refreshed. (An async callback passed to forEach is not awaited, so the
     * per-environment work is collected and awaited explicitly.)
     *
     * @param environments environments whose status should be updated in place
     */
    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        await Promise.all(environments.map((environment) => this.refreshEnvironment(environment)));
    }

    /**
     * Refresh one environment from its `pid` and `code` files.
     * Failures are logged and never propagated, so one broken environment
     * cannot abort the refresh of the others.
     */
    private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> {
        const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
        const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`;
        /* eslint-disable require-atomic-updates */
        try {
            // No pid file yet: leave the status untouched.
            if (!fs.existsSync(jobpidPath)) {
                return;
            }
            const pid: string = await fs.promises.readFile(jobpidPath, 'utf8');
            const alive: boolean = await isAlive(pid);
            environment.status = 'RUNNING';
            if (alive) {
                return;
            }
            // The process is gone; a final status is only set once the return-code file exists.
            if (!fs.existsSync(runnerReturnCodeFilePath)) {
                return;
            }
            const runnerReturnCode: string = await fs.promises.readFile(runnerReturnCodeFilePath, 'utf8');
            // Expected content: "<exit code> <timestamp>", as produced by the echo in startEnvironment().
            const match: RegExpMatchArray | null = runnerReturnCode.trim()
                .match(/^-?(\d+)\s+(\d+)$/);
            if (match === null) {
                return;
            }
            const code: string = match[1];
            if (parseInt(code, 10) === 0) {
                environment.setStatus('SUCCEEDED');
            } else {
                environment.setStatus('FAILED');
            }
        } catch (error) {
            this.log.error(`Update job status exception, error is ${error.message}`);
        }
    }

    /**
     * Prepare the environment's working folder, write the launch script
     * `nni_run.sh` and execute it on the local machine.
     *
     * @param environment environment to start; its runnerWorkingFolder, command
     *                    and trackingUrl are updated in place
     * @throws Error when the trial config has not been provided through config()
     */
    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        if (this.localTrialConfig === undefined) {
            throw new Error('Local trial config is not initialized');
        }
        // Need refactor, this temp folder path is not appropriate, there are two expId in this path
        const localTempFolder: string = path.join(this.experimentRootDir, this.experimentId,
            "environment-temp", "envs");
        const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs");
        environment.runnerWorkingFolder = path.join(localEnvCodeFolder, environment.id);
        await execMkdir(environment.runnerWorkingFolder);
        await execCopydir(localTempFolder, localEnvCodeFolder);
        // NOTE(review): because of the final `&&`, the `code` file is only written when the
        // runner exits with status 0, so the FAILED branch of the status refresh looks
        // unreachable through this script - confirm whether `;` was intended.
        // The continuation lines start at column 0 on purpose: leading whitespace would
        // become part of the command string.
        environment.command = `cd ${this.experimentRootDir} && \
${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \
1>${environment.runnerWorkingFolder}/trialrunner_stdout 2>${environment.runnerWorkingFolder}/trialrunner_stderr \
&& echo $? \`date +%s%3N\` >${environment.runnerWorkingFolder}/code`;
        const scriptPath: string = path.join(localEnvCodeFolder, 'nni_run.sh');
        await fs.promises.writeFile(scriptPath, environment.command, { encoding: 'utf8', mode: 0o777 });
        // Execute command in local machine
        runScript(scriptPath);
        environment.trackingUrl = `${environment.runnerWorkingFolder}`;
    }

    /**
     * Kill the process tree whose pid is recorded in the environment's `pid` file.
     *
     * @param environment environment to stop
     * @throws rejects when the pid file cannot be read
     */
    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
        const pidText: string = await fs.promises.readFile(jobpidPath, 'utf8');
        const pid: number = Number(pidText);
        // An empty or whitespace-only pid file converts to 0 and garbage converts to NaN;
        // neither may ever be handed to a kill call.
        if (!Number.isInteger(pid) || pid <= 0) {
            this.log.error(`Invalid pid '${pidText}' in ${jobpidPath}, environment ${environment.id} is not killed`);
            return;
        }
        // Pass a callback so kill failures are logged instead of surfacing as unhandled errors.
        tkill(pid, 'SIGKILL', (err) => {
            if (err) {
                this.log.error(`kill environment ${environment.id} error: ${err}`);
            }
        });
    }
}
...@@ -45,6 +45,10 @@ export class OpenPaiEnvironmentService extends EnvironmentService { ...@@ -45,6 +45,10 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return true; return true;
} }
public get getName(): string {
return 'pai';
}
public async config(key: string, value: string): Promise<void> { public async config(key: string, value: string): Promise<void> {
switch (key) { switch (key) {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
......
...@@ -63,6 +63,10 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -63,6 +63,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
return false; return false;
} }
public get getName(): string {
return 'remote';
}
public async config(key: string, value: string): Promise<void> { public async config(key: string, value: string): Promise<void> {
switch (key) { switch (key) {
case TrialConfigMetadataKey.MACHINE_LIST: case TrialConfigMetadataKey.MACHINE_LIST:
...@@ -134,7 +138,15 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -134,7 +138,15 @@ export class RemoteEnvironmentService extends EnvironmentService {
await executor.createFolder(remoteGpuScriptCollectorDir, true); await executor.createFolder(remoteGpuScriptCollectorDir, true);
await executor.allowPermission(true, nniRootDir); await executor.allowPermission(true, nniRootDir);
} }
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const tasks: Promise<void>[] = [];
environments.forEach(async (environment) => {
tasks.push(this.refreshEnvironment(environment));
});
await Promise.all(tasks);
}
private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> { private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> {
const executor = await this.getExecutor(environment.id); const executor = await this.getExecutor(environment.id);
const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
...@@ -176,14 +188,6 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -176,14 +188,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
} }
} }
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const tasks: Promise<void>[] = [];
environments.forEach(async (environment) => {
tasks.push(this.refreshEnvironment(environment));
});
await Promise.all(tasks);
}
/** /**
* If a environment is finished, release the connection resource * If a environment is finished, release the connection resource
* @param environment remote machine environment job detail * @param environment remote machine environment job detail
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment