Unverified Commit cf6a1de6 authored by Chi Song's avatar Chi Song Committed by GitHub
Browse files

Refactor: remote machine isolate OS commands phase 1 (#2376)

To support Windows nodes in remote mode, this PR adds a layer of commands (osCommands) to handle the differences between Windows and Unix-like OSes. To share code, ShellExecutor is added to enrich the original SshClient class.

I will implement the Windows version of the commands in the next phase.

This pattern can be extended to Local or other platforms in the future, so I moved the related code to the common folder for sharing.
parent edd3f8ac
...@@ -16,7 +16,7 @@ git clone https://github.com/Microsoft/nni.git ...@@ -16,7 +16,7 @@ git clone https://github.com/Microsoft/nni.git
to clone the source code to clone the source code
### 2. Prepare the debug environment and install dependencies** ### 2. Prepare the debug environment and install dependencies
Change directory to the source code folder, then run the command Change directory to the source code folder, then run the command
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData";
/**
 * Builds shell command strings for a Unix-like remote machine.
 *
 * Every method only returns the command text; nothing is executed here.
 * Path arguments are wrapped in single quotes and are NOT escaped, so a
 * path that itself contains a single quote will produce a malformed command.
 */
class LinuxCommands extends OsCommands {

    /**
     * Command that creates a folder, including missing parents (`mkdir -p`).
     * @param folderName folder to create
     * @param sharedFolder when true, `umask 0` is issued first so the new
     *                     folder is created without a permission mask
     */
    public createFolder(folderName: string, sharedFolder: boolean = false): string {
        const mkdir = `mkdir -p '${folderName}'`;
        return sharedFolder ? `umask 0; ${mkdir}` : mkdir;
    }

    /**
     * Command that grants mode 777 on the given folders.
     * @param isRecursive when true, `-R` is added
     * @param folders folders to change; each one is quoted separately
     */
    public allowPermission(isRecursive: boolean = false, ...folders: string[]): string {
        // Joining with "' '" closes one quoted argument and opens the next.
        const folderString = folders.join("' '");
        return isRecursive
            ? `chmod 777 -R '${folderString}'`
            : `chmod 777 '${folderString}'`;
    }

    /**
     * Command that removes a folder with `rm`.
     * @param folderName folder to remove
     * @param isRecursive when true, uses `-r`
     * @param isForce when true, appends `f` to the flags
     *
     * Flag selection: recursive -> `-r`/`-rf`; force only -> `-df`;
     * neither -> no flags at all (plain `rm`).
     */
    public removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): string {
        let flags = '';
        if (isForce || isRecursive) {
            flags = `-${isRecursive ? 'r' : 'd'}${isForce ? 'f' : ''} `;
        }
        return `rm ${flags}'${folderName}'`;
    }

    /**
     * Command that removes files under a folder.
     * @param folderName folder containing the files
     * @param filePattern file name or pattern, joined to the folder with joinPath
     *
     * NOTE(review): the joined path is single-quoted, so a shell will not
     * expand wildcard characters in filePattern - confirm that callers only
     * pass literal names, or that this is intended.
     */
    public removeFiles(folderName: string, filePattern: string): string {
        const files = this.joinPath(folderName, filePattern);
        return `rm '${files}'`;
    }

    /**
     * Command that prints the last lines of a file (`tail -n`).
     * @param fileName file to read
     * @param lineCount number of trailing lines, 1 by default
     */
    public readLastLines(fileName: string, lineCount: number = 1): string {
        return `tail -n ${lineCount} '${fileName}'`;
    }

    /**
     * Command that probes the process whose pid is stored in a file,
     * using `kill -0` on the file's content.
     * @param pidFileName file holding the pid
     */
    public isProcessAliveCommand(pidFileName: string): string {
        return `kill -0 \`cat '${pidFileName}'\``;
    }

    /**
     * Interprets the result of the command built by isProcessAliveCommand.
     * @param commandResult result of running that command
     * @returns true only when the exit code is 0
     */
    public isProcessAliveProcessOutput(commandResult: RemoteCommandResult): boolean {
        return commandResult.exitCode === 0;
    }

    /**
     * Command that kills the children of the process whose pid is stored
     * in a file (`pkill -P`).
     * @param pidFileName file holding the parent pid
     */
    public killChildProcesses(pidFileName: string): string {
        return `pkill -P \`cat '${pidFileName}'\``;
    }

    /**
     * Command that unpacks a gzipped tar archive into a folder.
     * @param tarFileName archive to extract
     * @param targetFolder destination passed to `tar -C`
     */
    public extractFile(tarFileName: string, targetFolder: string): string {
        return `tar -oxzf '${tarFileName}' -C '${targetFolder}'`;
    }

    /**
     * Command that runs a script with bash.
     * @param script path of a script file, or inline script text
     * @param isFile true when `script` is a file path; false when it is
     *               inline text to be passed to `bash -c`
     */
    public executeScript(script: string, isFile: boolean): string {
        if (isFile) {
            return `bash '${script}'`;
        }
        // Escape EVERY double quote: the inline script is wrapped in double
        // quotes below, and a string pattern would only replace the first one.
        const escaped = script.replace(/"/g, '\\"');
        return `bash -c "${escaped}"`;
    }
}
export { LinuxCommands };
...@@ -8,7 +8,7 @@ import { getLogger, Logger } from '../../common/log'; ...@@ -8,7 +8,7 @@ import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils'; import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData'; import { GPUInfo } from '../common/gpuData';
import { import {
parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, SSHClientManager parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, ExecutorManager
} from './remoteMachineData'; } from './remoteMachineData';
type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
...@@ -18,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; ...@@ -18,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
*/ */
export class GPUScheduler { export class GPUScheduler {
private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; private readonly machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>;
private readonly log: Logger = getLogger(); private readonly log: Logger = getLogger();
private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin'; private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
private roundRobinIndex: number = 0; private roundRobinIndex: number = 0;
...@@ -26,12 +26,12 @@ export class GPUScheduler { ...@@ -26,12 +26,12 @@ export class GPUScheduler {
/** /**
* Constructor * Constructor
* @param machineSSHClientMap map from remote machine to sshClient * @param machineExecutorMap map from remote machine to executor
*/ */
constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) { constructor(machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>) {
assert(machineSSHClientMap.size > 0); assert(machineExecutorMap.size > 0);
this.machineSSHClientMap = machineSSHClientMap; this.machineExecutorMap = machineExecutorMap;
this.configuredRMs = Array.from(machineSSHClientMap.keys()); this.configuredRMs = Array.from(machineExecutorMap.keys());
} }
/** /**
...@@ -43,7 +43,7 @@ export class GPUScheduler { ...@@ -43,7 +43,7 @@ export class GPUScheduler {
requiredGPUNum = 0; requiredGPUNum = 0;
} }
assert(requiredGPUNum >= 0); assert(requiredGPUNum >= 0);
const allRMs: RemoteMachineMeta[] = Array.from(this.machineSSHClientMap.keys()); const allRMs: RemoteMachineMeta[] = Array.from(this.machineExecutorMap.keys());
assert(allRMs.length > 0); assert(allRMs.length > 0);
// Step 1: Check if required GPU number not exceeds the total GPU number in all machines // Step 1: Check if required GPU number not exceeds the total GPU number in all machines
...@@ -135,7 +135,7 @@ export class GPUScheduler { ...@@ -135,7 +135,7 @@ export class GPUScheduler {
*/ */
private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> { private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>(); const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => { this.machineExecutorMap.forEach((executorManager: ExecutorManager, rmMeta: RemoteMachineMeta) => {
// Assgin totoal GPU count as init available GPU number // Assgin totoal GPU count as init available GPU number
if (rmMeta.gpuSummary !== undefined) { if (rmMeta.gpuSummary !== undefined) {
const availableGPUs: GPUInfo[] = []; const availableGPUs: GPUInfo[] = [];
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { RemoteCommandResult } from "./remoteMachineData";
/**
 * Base class for OS-specific builders of remote shell commands.
 *
 * Subclasses supply the command text for each operation; this class only
 * provides path joining that is shared between them.
 */
abstract class OsCommands {

    /** Separator inserted between path segments by joinPath. */
    protected pathSpliter: string = '/';

    /**
     * Matches every run of two or more consecutive separators.
     *
     * It is built from pathSpliter when this base class is initialized, which
     * happens before any subclass field initializer runs; a subclass that
     * changes pathSpliter through a field initializer therefore has to
     * redefine this pattern as well.
     */
    protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`, 'g');

    /** Command that creates a folder; `sharedFolder` asks for a folder usable by other users. */
    public abstract createFolder(folderName: string, sharedFolder: boolean): string;
    /** Command that opens up permissions on the given folders, optionally recursively. */
    public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string;
    /** Command that removes a folder, optionally recursively and/or forcibly. */
    public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string;
    /** Command that removes files matching `filePattern` under a folder. */
    public abstract removeFiles(folderOrFileName: string, filePattern: string): string;
    /** Command that prints the last `lineCount` lines of a file. */
    public abstract readLastLines(fileName: string, lineCount: number): string;
    /** Command that probes the process whose pid is stored in `pidFileName`. */
    public abstract isProcessAliveCommand(pidFileName: string): string;
    /** Turns the result of the isProcessAliveCommand command into a boolean. */
    public abstract isProcessAliveProcessOutput(result: RemoteCommandResult): boolean;
    /** Command that kills the children of the process whose pid is stored in `pidFileName`. */
    public abstract killChildProcesses(pidFileName: string): string;
    /** Command that extracts an archive into `targetFolder`. */
    public abstract extractFile(tarFileName: string, targetFolder: string): string;
    /** Command that runs `script`, which is a file path when `isFile` is true and inline text otherwise. */
    public abstract executeScript(script: string, isFile: boolean): string;

    /**
     * Joins path segments with pathSpliter.
     *
     * Empty segments are dropped and every run of repeated separators is
     * collapsed to a single one.
     * @param paths segments to join
     * @returns the joined path, or '.' when there is no non-empty segment
     */
    public joinPath(...paths: string[]): string {
        const joined: string = paths.filter((path: string) => path !== '').join(this.pathSpliter);
        if (joined === '') {
            return '.';
        }
        // A non-global pattern would collapse only the first run of separators,
        // so upgrade a non-global override supplied by a subclass.
        const collapse: RegExp = this.multiplePathSpliter.global
            ? this.multiplePathSpliter
            : new RegExp(this.multiplePathSpliter.source, `${this.multiplePathSpliter.flags}g`);
        return joined.replace(collapse, this.pathSpliter);
    }
}
export { OsCommands };
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
'use strict'; 'use strict';
import * as fs from 'fs'; import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUInfo, GPUSummary } from '../common/gpuData'; import { GPUInfo, GPUSummary } from '../common/gpuData';
import { ShellExecutor } from './shellExecutor';
/** /**
* Metadata of remote machine for configuration and statuc query * Metadata of remote machine for configuration and statuc query
...@@ -72,7 +70,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -72,7 +70,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public gpuIndices: GPUInfo[]; public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: TrialJobApplicationForm) { workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.submitTime = submitTime; this.submitTime = submitTime;
...@@ -84,149 +82,88 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -84,149 +82,88 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
} }
/** /**
* The remote machine ssh client used for trial and gpu detector * The remote machine executor manager
*/ */
export class SSHClient { export class ExecutorManager {
private readonly sshClient: Client; private readonly executorArray: ShellExecutor[];
private usedConnectionNumber: number; //count the connection number of every client
constructor(sshClient: Client, usedConnectionNumber: number) {
this.sshClient = sshClient;
this.usedConnectionNumber = usedConnectionNumber;
}
public get getSSHClientInstance(): Client {
return this.sshClient;
}
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
public addUsedConnectionNumber(): void {
this.usedConnectionNumber += 1;
}
public minusUsedConnectionNumber(): void {
this.usedConnectionNumber -= 1;
}
}
/**
* The remote machine ssh client manager
*/
export class SSHClientManager {
private readonly sshClientArray: SSHClient[];
private readonly maxTrialNumberPerConnection: number; private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta; private readonly rmMeta: RemoteMachineMeta;
constructor(sshClientArray: SSHClient[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) { constructor(executorArray: ShellExecutor[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
this.rmMeta = rmMeta; this.rmMeta = rmMeta;
this.sshClientArray = sshClientArray; this.executorArray = executorArray;
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection; this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
} }
/** /**
* find a available ssh client in ssh array, if no ssh client available, return undefined * find a available executor, if no executor available, return a new one
*/ */
public async getAvailableSSHClient(): Promise<Client> { public async getAvailableExecutor(): Promise<ShellExecutor> {
const deferred: Deferred<Client> = new Deferred<Client>(); for (const index of this.executorArray.keys()) {
for (const index of this.sshClientArray.keys()) { const connectionNumber: number = this.executorArray[index].getUsedConnectionNumber;
const connectionNumber: number = this.sshClientArray[index].getUsedConnectionNumber;
if (connectionNumber < this.maxTrialNumberPerConnection) { if (connectionNumber < this.maxTrialNumberPerConnection) {
this.sshClientArray[index].addUsedConnectionNumber(); this.executorArray[index].addUsedConnectionNumber();
deferred.resolve(this.sshClientArray[index].getSSHClientInstance);
return deferred.promise; return this.executorArray[index];
} }
} }
//init a new ssh client if could not get an available one //init a new executor if could not get an available one
return this.initNewSSHClient(); return await this.initNewShellExecutor();
} }
/** /**
* add a new ssh client to sshClientArray * add a new executor to executorArray
* @param sshClient SSH Client * @param executor ShellExecutor
*/ */
public addNewSSHClient(client: Client): void { public addNewShellExecutor(executor: ShellExecutor): void {
this.sshClientArray.push(new SSHClient(client, 1)); this.executorArray.push(executor);
} }
/** /**
* first ssh client instance is used for gpu collector and host job * first executor instance is used for gpu collector and host job
*/ */
public getFirstSSHClient(): Client { public getFirstExecutor(): ShellExecutor {
return this.sshClientArray[0].getSSHClientInstance; return this.executorArray[0];
} }
/** /**
* close all of ssh client * close all of executor
*/ */
public closeAllSSHClient(): void { public closeAllExecutor(): void {
for (const sshClient of this.sshClientArray) { for (const executor of this.executorArray) {
sshClient.getSSHClientInstance.end(); executor.close();
} }
} }
/** /**
* retrieve resource, minus a number for given ssh client * retrieve resource, minus a number for given executor
* @param client SSH Client * @param executor executor
*/ */
public releaseConnection(client: Client | undefined): void { public releaseConnection(executor: ShellExecutor | undefined): void {
if (client === undefined) { if (executor === undefined) {
throw new Error(`could not release a undefined ssh client`); throw new Error(`could not release a undefined executor`);
} }
for (const index of this.sshClientArray.keys()) { for (const index of this.executorArray.keys()) {
if (this.sshClientArray[index].getSSHClientInstance === client) { if (this.executorArray[index] === executor) {
this.sshClientArray[index].minusUsedConnectionNumber(); this.executorArray[index].minusUsedConnectionNumber();
break; break;
} }
} }
} }
/** /**
* Create a new ssh connection client and initialize it * Create a new connection executor and initialize it
*/ */
private initNewSSHClient(): Promise<Client> { private async initNewShellExecutor(): Promise<ShellExecutor> {
const deferred: Deferred<Client> = new Deferred<Client>(); const executor = new ShellExecutor();
const conn: Client = new Client(); await executor.initialize(this.rmMeta);
const connectConfig: ConnectConfig = { return executor;
host: this.rmMeta.ip,
port: this.rmMeta.port,
username: this.rmMeta.username,
tryKeyboard: true };
if (this.rmMeta.passwd !== undefined) {
connectConfig.password = this.rmMeta.passwd;
} else if (this.rmMeta.sshKeyPath !== undefined) {
if (!fs.existsSync(this.rmMeta.sshKeyPath)) {
//SSh key path is not a valid file, reject
deferred.reject(new Error(`${this.rmMeta.sshKeyPath} does not exist.`));
}
const privateKey: string = fs.readFileSync(this.rmMeta.sshKeyPath, 'utf8');
connectConfig.privateKey = privateKey;
connectConfig.passphrase = this.rmMeta.passphrase;
} else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
}
conn.on('ready', () => {
this.addNewSSHClient(conn);
deferred.resolve(conn);
})
.on('error', (err: Error) => {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => {
finish([this.rmMeta.passwd]);
})
.connect(connectConfig);
return deferred.promise;
} }
} }
export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType}; export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType };
export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string}; export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string };
export enum ScheduleResultType { export enum ScheduleResultType {
// Schedule succeeded // Schedule succeeded
...@@ -240,7 +177,7 @@ export enum ScheduleResultType { ...@@ -240,7 +177,7 @@ export enum ScheduleResultType {
} }
export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string = export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string =
`#!/bin/bash `#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \
NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5} NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5}
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
...@@ -251,7 +188,7 @@ python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8 ...@@ -251,7 +188,7 @@ python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8
echo $? \`date +%s%3N\` >{12}`; echo $? \`date +%s%3N\` >{12}`;
export const HOST_JOB_SHELL_FORMAT: string = export const HOST_JOB_SHELL_FORMAT: string =
`#!/bin/bash `#!/bin/bash
cd {0} cd {0}
echo $$ >{1} echo $$ >{1}
eval {2} >stdout 2>stderr eval {2} >stdout 2>stderr
......
...@@ -7,7 +7,6 @@ import * as assert from 'assert'; ...@@ -7,7 +7,6 @@ import * as assert from 'assert';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import { Client } from 'ssh2';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import * as component from '../../common/component'; import * as component from '../../common/component';
...@@ -30,22 +29,22 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; ...@@ -30,22 +29,22 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execCopydir, execMkdir, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util'; import { execCopydir, execMkdir, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
import { GPUScheduler } from './gpuScheduler'; import { GPUScheduler } from './gpuScheduler';
import { import {
RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail,
ScheduleResultType, SSHClientManager ScheduleResultType, ExecutorManager
} from './remoteMachineData'; } from './remoteMachineData';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer'; import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { SSHClientUtility } from './sshClientUtility'; import { ShellExecutor } from 'training_service/remote_machine/shellExecutor';
/** /**
* Training Service implementation for Remote Machine (Linux) * Training Service implementation for Remote Machine (Linux)
*/ */
@component.Singleton @component.Singleton
class RemoteMachineTrainingService implements TrainingService { class RemoteMachineTrainingService implements TrainingService {
private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; //machine ssh client map private readonly machineExecutorManagerMap: Map<RemoteMachineMeta, ExecutorManager>; //machine excutor map
private readonly trialSSHClientMap: Map<string, Client>; //trial ssh client map private readonly trialExecutorMap: Map<string, ShellExecutor>; //trial excutor map
private readonly trialJobsMap: Map<string, RemoteMachineTrialJobDetail>; private readonly trialJobsMap: Map<string, RemoteMachineTrialJobDetail>;
private readonly MAX_TRIAL_NUMBER_PER_SSHCONNECTION: number = 5; // every ssh client has a max trial concurrency number private readonly MAX_TRIAL_NUMBER_PER_EXECUTOR: number = 5; // every excutor has a max trial concurrency number
private readonly expRootDir: string; private readonly expRootDir: string;
private readonly remoteExpRootDir: string; private readonly remoteExpRootDir: string;
private trialConfig: TrialConfig | undefined; private trialConfig: TrialConfig | undefined;
...@@ -67,8 +66,8 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -67,8 +66,8 @@ class RemoteMachineTrainingService implements TrainingService {
this.remoteOS = 'linux'; this.remoteOS = 'linux';
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>(); this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.trialSSHClientMap = new Map<string, Client>(); this.trialExecutorMap = new Map<string, ShellExecutor>();
this.machineSSHClientMap = new Map<RemoteMachineMeta, SSHClientManager>(); this.machineExecutorManagerMap = new Map<RemoteMachineMeta, ExecutorManager>();
this.jobQueue = []; this.jobQueue = [];
this.expRootDir = getExperimentRootDir(); this.expRootDir = getExperimentRootDir();
this.remoteExpRootDir = this.getRemoteExperimentRootDir(); this.remoteExpRootDir = this.getRemoteExperimentRootDir();
...@@ -111,38 +110,34 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -111,38 +110,34 @@ class RemoteMachineTrainingService implements TrainingService {
} }
/** /**
* give trial a ssh connection * give trial an executor
* @param trial remote machine trial job detail * @param trial remote machine trial job detail
*/ */
public async allocateSSHClientForTrial(trial: RemoteMachineTrialJobDetail): Promise<void> { public async allocateExecutorForTrial(trial: RemoteMachineTrialJobDetail): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (trial.rmMeta === undefined) { if (trial.rmMeta === undefined) {
throw new Error(`rmMeta not set in trial ${trial.id}`); throw new Error(`rmMeta not set in trial ${trial.id}`);
} }
const sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(trial.rmMeta); const executorManager: ExecutorManager | undefined = this.machineExecutorManagerMap.get(trial.rmMeta);
if (sshClientManager === undefined) { if (executorManager === undefined) {
throw new Error(`remoteSSHClient not initialized`); throw new Error(`executorManager not initialized`);
} }
const sshClient: Client = await sshClientManager.getAvailableSSHClient(); const shellExecutor: ShellExecutor = await executorManager.getAvailableExecutor();
this.trialSSHClientMap.set(trial.id, sshClient); this.trialExecutorMap.set(trial.id, shellExecutor);
deferred.resolve();
return deferred.promise;
} }
/** /**
* If a trial is finished, release the connection resource * If a trial is finished, release the connection resource
* @param trial remote machine trial job detail * @param trial remote machine trial job detail
*/ */
public releaseTrialSSHClient(trial: RemoteMachineTrialJobDetail): void { public releaseTrialExecutor(trial: RemoteMachineTrialJobDetail): void {
if (trial.rmMeta === undefined) { if (trial.rmMeta === undefined) {
throw new Error(`rmMeta not set in trial ${trial.id}`); throw new Error(`rmMeta not set in trial ${trial.id}`);
} }
const sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(trial.rmMeta); const executorManager: ExecutorManager | undefined = this.machineExecutorManagerMap.get(trial.rmMeta);
if (sshClientManager === undefined) { if (executorManager === undefined) {
throw new Error(`sshClientManager not initialized`); throw new Error(`executorManager not initialized`);
} }
sshClientManager.releaseConnection(this.trialSSHClientMap.get(trial.id)); executorManager.releaseConnection(this.trialExecutorMap.get(trial.id));
} }
/** /**
...@@ -152,7 +147,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -152,7 +147,7 @@ class RemoteMachineTrainingService implements TrainingService {
const jobs: TrialJobDetail[] = []; const jobs: TrialJobDetail[] = [];
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>(); const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key, value] of this.trialJobsMap) { for (const [key,] of this.trialJobsMap) {
jobs.push(await this.getTrialJob(key)); jobs.push(await this.getTrialJob(key));
} }
deferred.resolve(jobs); deferred.resolve(jobs);
...@@ -171,16 +166,16 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -171,16 +166,16 @@ class RemoteMachineTrainingService implements TrainingService {
} }
//TO DO: add another job status, and design new job status change logic //TO DO: add another job status, and design new job status change logic
if (trialJob.status === 'RUNNING' || trialJob.status === 'UNKNOWN') { if (trialJob.status === 'RUNNING' || trialJob.status === 'UNKNOWN') {
// Get ssh client where the job is running // Get executor where the job is running
if (trialJob.rmMeta === undefined) { if (trialJob.rmMeta === undefined) {
throw new Error(`rmMeta not set for submitted job ${trialJobId}`); throw new Error(`rmMeta not set for submitted job ${trialJobId}`);
} }
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id); const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJob.id);
if (sshClient === undefined) { if (executor === undefined) {
throw new Error(`Invalid job id: ${trialJobId}, cannot find ssh client`); throw new Error(`Invalid job id: ${trialJobId}, cannot find executor`);
} }
return this.updateTrialJobStatus(trialJob, sshClient); return this.updateTrialJobStatus(trialJob, executor);
} else { } else {
return trialJob; return trialJob;
} }
...@@ -255,10 +250,8 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -255,10 +250,8 @@ class RemoteMachineTrainingService implements TrainingService {
* @param trialJobId ID of trial job * @param trialJobId ID of trial job
*/ */
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> { public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJob === undefined) { if (trialJob === undefined) {
deferred.reject();
throw new Error(`trial job id ${trialJobId} not found`); throw new Error(`trial job id ${trialJobId} not found`);
} }
...@@ -268,17 +261,16 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -268,17 +261,16 @@ class RemoteMachineTrainingService implements TrainingService {
this.jobQueue.splice(index, 1); this.jobQueue.splice(index, 1);
} }
// Get ssh client where the job is running // Get executor where the job is running
if (trialJob.rmMeta !== undefined) { if (trialJob.rmMeta !== undefined) {
// If the trial job is already scheduled, check its status and kill the trial process in remote machine // If the trial job is already scheduled, check its status and kill the trial process in remote machine
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id); const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJob.id);
if (sshClient === undefined) { if (executor === undefined) {
deferred.reject(); throw new Error(`Invalid job id ${trialJobId}, cannot find executor`);
throw new Error(`Invalid job id ${trialJobId}, cannot find ssh client`);
} }
if (trialJob.status === 'UNKNOWN') { if (trialJob.status === 'UNKNOWN') {
this.releaseTrialSSHClient(trialJob); this.releaseTrialExecutor(trialJob);
trialJob.status = 'USER_CANCELED'; trialJob.status = 'USER_CANCELED';
return return
} }
...@@ -287,8 +279,8 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -287,8 +279,8 @@ class RemoteMachineTrainingService implements TrainingService {
try { try {
// Mark the toEarlyStop tag here // Mark the toEarlyStop tag here
trialJob.isEarlyStopped = isEarlyStopped; trialJob.isEarlyStopped = isEarlyStopped;
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, sshClient); await executor.killChildProcesses(jobpidPath);
this.releaseTrialSSHClient(trialJob); this.releaseTrialExecutor(trialJob);
} catch (error) { } catch (error) {
// Not handle the error since pkill failed will not impact trial job's current status // Not handle the error since pkill failed will not impact trial job's current status
this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`); this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`);
...@@ -303,7 +295,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -303,7 +295,7 @@ class RemoteMachineTrainingService implements TrainingService {
/** /**
* Set culster metadata * Set culster metadata
* @param key metadata key * @param key metadata key
* //1. MACHINE_LIST -- create ssh client connect of machine list * //1. MACHINE_LIST -- create executor of machine list
* //2. TRIAL_CONFIG -- trial configuration * //2. TRIAL_CONFIG -- trial configuration
* @param value metadata value * @param value metadata value
*/ */
...@@ -314,7 +306,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -314,7 +306,7 @@ class RemoteMachineTrainingService implements TrainingService {
break; break;
case TrialConfigMetadataKey.MACHINE_LIST: case TrialConfigMetadataKey.MACHINE_LIST:
await this.setupConnections(value); await this.setupConnections(value);
this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap); this.gpuScheduler = new GPUScheduler(this.machineExecutorManagerMap);
break; break;
case TrialConfigMetadataKey.TRIAL_CONFIG: { case TrialConfigMetadataKey.TRIAL_CONFIG: {
const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value); const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
...@@ -359,10 +351,8 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -359,10 +351,8 @@ class RemoteMachineTrainingService implements TrainingService {
* Get culster metadata * Get culster metadata
* @param key metadata key * @param key metadata key
*/ */
public getClusterMetadata(key: string): Promise<string> { public async getClusterMetadata(key: string): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); return "";
return deferred.promise;
} }
/** /**
...@@ -392,14 +382,14 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -392,14 +382,14 @@ class RemoteMachineTrainingService implements TrainingService {
*/ */
private async cleanupConnections(): Promise<void> { private async cleanupConnections(): Promise<void> {
try { try {
for (const [rmMeta, sshClientManager] of this.machineSSHClientMap.entries()) { for (const [rmMeta, executorManager] of this.machineExecutorManagerMap.entries()) {
const jobpidPath: string = unixPathJoin(this.getRemoteScriptsPath(rmMeta.username), 'pid'); const jobpidPath: string = unixPathJoin(this.getRemoteScriptsPath(rmMeta.username), 'pid');
const client: Client | undefined = sshClientManager.getFirstSSHClient(); const executor: ShellExecutor | undefined = executorManager.getFirstExecutor();
if (client !== undefined) { if (executor !== undefined) {
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, client); await executor.killChildProcesses(jobpidPath);
await SSHClientUtility.remoteExeCommand(`rm -rf ${this.getRemoteScriptsPath(rmMeta.username)}`, client); await executor.removeFolder(this.getRemoteScriptsPath(rmMeta.username));
} }
sshClientManager.closeAllSSHClient(); executorManager.closeAllExecutor();
} }
} catch (error) { } catch (error) {
//ignore error, this function is called to cleanup remote connections when experiment is stopping //ignore error, this function is called to cleanup remote connections when experiment is stopping
...@@ -418,10 +408,10 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -418,10 +408,10 @@ class RemoteMachineTrainingService implements TrainingService {
rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => { rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => {
rmMeta.occupiedGpuIndexMap = new Map<number, number>(); rmMeta.occupiedGpuIndexMap = new Map<number, number>();
const sshClientManager: SSHClientManager = new SSHClientManager([], this.MAX_TRIAL_NUMBER_PER_SSHCONNECTION, rmMeta); const executorManager: ExecutorManager = new ExecutorManager([], this.MAX_TRIAL_NUMBER_PER_EXECUTOR, rmMeta);
const sshClient: Client = await sshClientManager.getAvailableSSHClient(); const executor: ShellExecutor = await executorManager.getAvailableExecutor();
this.machineSSHClientMap.set(rmMeta, sshClientManager); this.machineExecutorManagerMap.set(rmMeta, executorManager);
await this.initRemoteMachineOnConnected(rmMeta, sshClient); await this.initRemoteMachineOnConnected(rmMeta, executor);
if (++connectedRMNum === rmMetaList.length) { if (++connectedRMNum === rmMetaList.length) {
deferred.resolve(); deferred.resolve();
} }
...@@ -430,26 +420,25 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -430,26 +420,25 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> { private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
// Create root working directory after ssh connection is ready // Create root working directory after executor is ready
const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni'); const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni');
await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn); await executor.createFolder(this.remoteExpRootDir);
// the directory to store temp scripts in remote machine // the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username); const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username);
await SSHClientUtility.remoteExeCommand(`(umask 0 ; mkdir -p ${remoteGpuScriptCollectorDir})`, conn); await executor.createFolder(remoteGpuScriptCollectorDir, true);
await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn); await executor.allowPermission(false, nniRootDir, `${nniRootDir}/*`, `${nniRootDir}/scripts/*`);
//Begin to execute gpu_metrics_collection scripts //Begin to execute gpu_metrics_collection scripts
const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir); const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir);
SSHClientUtility.remoteExeCommand(`bash -c '${script}'`, conn, true); executor.executeScript(script, false, true);
const disposable: Rx.IDisposable = this.timer.subscribe( const disposable: Rx.IDisposable = this.timer.subscribe(
async (tick: number) => { async () => {
const cmdresult: RemoteCommandResult = await SSHClientUtility.remoteExeCommand( const cmdresult = await executor.readLastLines(unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics'));
`tail -n 1 ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics')}`, conn); if (cmdresult !== "") {
if (cmdresult !== undefined && cmdresult.stdout !== undefined && cmdresult.stdout.length > 0) { rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult);
rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout);
if (rmMeta.gpuSummary.gpuCount === 0) { if (rmMeta.gpuSummary.gpuCount === 0) {
this.log.warning(`No GPU found on remote machine ${rmMeta.ip}`); this.log.warning(`No GPU found on remote machine ${rmMeta.ip}`);
this.timer.unsubscribe(disposable); this.timer.unsubscribe(disposable);
...@@ -478,7 +467,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -478,7 +467,7 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
// get an ssh client from scheduler // get an executor from scheduler
const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobDetail); const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobDetail);
if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) { if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
const errorMessage: string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`; const errorMessage: string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
...@@ -492,7 +481,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -492,7 +481,7 @@ class RemoteMachineTrainingService implements TrainingService {
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta; trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
await this.allocateSSHClientForTrial(trialJobDetail); await this.allocateExecutorForTrial(trialJobDetail);
await this.launchTrialOnScheduledMachine( await this.launchTrialOnScheduledMachine(
trialJobId, trialWorkingFolder, trialJobDetail.form, rmScheduleInfo); trialJobId, trialWorkingFolder, trialJobDetail.form, rmScheduleInfo);
...@@ -518,9 +507,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -518,9 +507,9 @@ class RemoteMachineTrainingService implements TrainingService {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice; const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice;
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId); const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJobId);
if (sshClient === undefined) { if (executor === undefined) {
assert(false, 'sshClient is undefined.'); assert(false, 'ShellExecutor is undefined.');
// for lint // for lint
return; return;
...@@ -532,8 +521,8 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -532,8 +521,8 @@ class RemoteMachineTrainingService implements TrainingService {
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${trialWorkingFolder}`, sshClient); await executor.createFolder(trialWorkingFolder);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${unixPathJoin(trialWorkingFolder, '.nni')}`, sshClient); await executor.createFolder(unixPathJoin(trialWorkingFolder, '.nni'));
// RemoteMachineRunShellFormat is the run shell format string, // RemoteMachineRunShellFormat is the run shell format string,
// See definition in remoteMachineData.ts // See definition in remoteMachineData.ts
...@@ -586,13 +575,13 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -586,13 +575,13 @@ class RemoteMachineTrainingService implements TrainingService {
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptTrialContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptTrialContent, { encoding: 'utf8' });
await this.writeParameterFile(trialJobId, form.hyperParameters); await this.writeParameterFile(trialJobId, form.hyperParameters);
// Copy files in codeDir to remote working directory // Copy files in codeDir to remote working directory
await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS); await executor.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, this.remoteOS);
// Execute command in remote machine // Execute command in remote machine
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient, true); executor.executeScript(unixPathJoin(trialWorkingFolder, 'run.sh'), true, true);
} }
private getRmMetaByHost(host: string): RemoteMachineMeta { private getRmMetaByHost(host: string): RemoteMachineMeta {
for (const [rmMeta, client] of this.machineSSHClientMap.entries()) { for (const rmMeta of this.machineExecutorManagerMap.keys()) {
if (rmMeta.ip === host) { if (rmMeta.ip === host) {
return rmMeta; return rmMeta;
} }
...@@ -600,18 +589,18 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -600,18 +589,18 @@ class RemoteMachineTrainingService implements TrainingService {
throw new Error(`Host not found: ${host}`); throw new Error(`Host not found: ${host}`);
} }
private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, sshClient: Client): Promise<TrialJobDetail> { private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, executor: ShellExecutor): Promise<TrialJobDetail> {
const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>(); const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
const jobpidPath: string = this.getJobPidPath(trialJob.id); const jobpidPath: string = this.getJobPidPath(trialJob.id);
const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code'); const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code');
/* eslint-disable require-atomic-updates */ /* eslint-disable require-atomic-updates */
try { try {
const killResult: number = (await SSHClientUtility.remoteExeCommand(`kill -0 \`cat ${jobpidPath}\``, sshClient)).exitCode; const isAlive = await executor.isProcessAlive(jobpidPath);
// if the process of jobpid is not alive any more // if the process of jobpid is not alive any more
if (killResult !== 0) { if (!isAlive) {
const trailReturnCode: string = await SSHClientUtility.getRemoteFileContent(trialReturnCodeFilePath, sshClient); const trialReturnCode: string = await executor.getRemoteFileContent(trialReturnCodeFilePath);
this.log.debug(`trailjob ${trialJob.id} return code: ${trailReturnCode}`); this.log.debug(`trailjob ${trialJob.id} return code: ${trialReturnCode}`);
const match: RegExpMatchArray | null = trailReturnCode.trim() const match: RegExpMatchArray | null = trialReturnCode.trim()
.match(/^(\d+)\s+(\d+)$/); .match(/^(\d+)\s+(\d+)$/);
if (match !== null) { if (match !== null) {
const { 1: code, 2: timestamp } = match; const { 1: code, 2: timestamp } = match;
...@@ -627,7 +616,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -627,7 +616,7 @@ class RemoteMachineTrainingService implements TrainingService {
} }
} }
trialJob.endTime = parseInt(timestamp, 10); trialJob.endTime = parseInt(timestamp, 10);
this.releaseTrialSSHClient(trialJob); this.releaseTrialExecutor(trialJob);
} }
this.log.debug(`trailJob status update: ${trialJob.id}, ${trialJob.status}`); this.log.debug(`trailJob status update: ${trialJob.id}, ${trialJob.status}`);
} }
...@@ -671,9 +660,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -671,9 +660,9 @@ class RemoteMachineTrainingService implements TrainingService {
} }
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> { private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId); const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJobId);
if (sshClient === undefined) { if (executor === undefined) {
throw new Error('sshClient is undefined.'); throw new Error('ShellExecutor is undefined.');
} }
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId); const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
...@@ -683,7 +672,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -683,7 +672,7 @@ class RemoteMachineTrainingService implements TrainingService {
const localFilepath: string = path.join(trialLocalTempFolder, fileName); const localFilepath: string = path.join(trialLocalTempFolder, fileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
await SSHClientUtility.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName), sshClient); await executor.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName));
} }
} }
......
...@@ -6,33 +6,159 @@ ...@@ -6,33 +6,159 @@
import * as assert from 'assert'; import * as assert from 'assert';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import { Client, ClientChannel, SFTPWrapper } from 'ssh2'; import * as fs from 'fs';
import { Client, ClientChannel, SFTPWrapper, ConnectConfig } from 'ssh2';
import { Deferred } from "ts-deferred";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import * as stream from 'stream'; import * as stream from 'stream';
import { Deferred } from 'ts-deferred'; import { OsCommands } from "./osCommands";
import { NNIError, NNIErrorNames } from '../../common/errors'; import { LinuxCommands } from "./extends/linuxCommands";
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { execRemove, tarAdd } from '../common/util'; import { execRemove, tarAdd } from '../common/util';
import { RemoteCommandResult } from './remoteMachineData'; import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils';
/** class ShellExecutor {
* private sshClient: Client = new Client();
* Utility for frequent operations towards SSH client private osCommands: OsCommands | undefined;
* private usedConnectionNumber: number = 0; //count the connection number of every client
*/
export namespace SSHClientUtility { protected pathSpliter: string = '/';
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`);
/**
 * Open the SSH connection described by `rmMeta` and pick the OS command set
 * used by every other method of this executor.
 *
 * Authentication uses `rmMeta.passwd` when it is set, otherwise the private
 * key at `rmMeta.sshKeyPath` (with `rmMeta.passphrase`).
 *
 * @param rmMeta connection settings of the remote machine
 * @returns a promise resolved once the connection is ready and the remote OS
 *          has been probed
 * @throws Error (as a rejected promise) when the key file does not exist, when
 *         neither a password nor a key path is configured, when the SSH
 *         connection fails, or when the remote machine runs Windows
 */
public async initialize(rmMeta: RemoteMachineMeta): Promise<void> {
    const deferred: Deferred<void> = new Deferred<void>();

    const connectConfig: ConnectConfig = {
        host: rmMeta.ip,
        port: rmMeta.port,
        username: rmMeta.username,
        tryKeyboard: true
    };
    if (rmMeta.passwd !== undefined) {
        connectConfig.password = rmMeta.passwd;
    } else if (rmMeta.sshKeyPath !== undefined) {
        if (!fs.existsSync(rmMeta.sshKeyPath)) {
            // SSH key path is not a valid file: fail here instead of falling
            // through to readFileSync, which would hide this message behind ENOENT.
            throw new Error(`${rmMeta.sshKeyPath} does not exist.`);
        }
        const privateKey: string = fs.readFileSync(rmMeta.sshKeyPath, 'utf8');

        connectConfig.privateKey = privateKey;
        connectConfig.passphrase = rmMeta.passphrase;
    } else {
        // Without credentials there is nothing to connect with, so do not even try.
        throw new Error(`No valid passwd or sshKeyPath is configed.`);
    }

    this.sshClient.on('ready', async () => {
        // Any failure inside this async handler must reach the caller through
        // the deferred; otherwise it is an unhandled rejection and the promise
        // returned by initialize() never settles.
        try {
            // check OS type: windows or else
            const result = await this.execute("ver");
            if (result.exitCode === 0 && result.stdout.search("Windows") > -1) {
                // not implement Windows commands yet.
                throw new Error("not implement Windows commands yet.");
            }
            this.osCommands = new LinuxCommands();
            deferred.resolve();
        } catch (error) {
            deferred.reject(error);
        }
    }).on('error', (err: Error) => {
        // SSH connection error, reject with error message
        deferred.reject(new Error(err.message));
    }).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => {
        // Answer the keyboard-interactive prompt with the configured password.
        finish([rmMeta.passwd]);
    }).connect(connectConfig);

    return deferred.promise;
}
/**
 * Close this executor by ending the underlying SSH client.
 */
public close(): void {
this.sshClient.end();
}
/**
 * Current value of the usage counter maintained by
 * addUsedConnectionNumber / minusUsedConnectionNumber.
 * NOTE: this is a property accessor, so it is read as
 * `executor.getUsedConnectionNumber` without parentheses.
 */
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
/**
 * Increase the usage counter of this executor by one.
 */
public addUsedConnectionNumber(): void {
    this.usedConnectionNumber++;
}
/**
 * Decrease the usage counter of this executor by one.
 */
public minusUsedConnectionNumber(): void {
    this.usedConnectionNumber--;
}
/**
 * Create a folder on the remote machine.
 *
 * @param folderName path of the folder to create
 * @param sharedFolder forwarded to the OS command layer, which builds the
 *                     "shared" variant of the create command when true
 * @returns true only when the remote command exits with code 0
 */
public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder);
    const commandResult = await this.execute(commandText);
    // A non-zero exit code is a failure; `>= 0` would report it as success.
    return commandResult.exitCode === 0;
}
/**
 * Open up the permissions of the given remote folders.
 *
 * @param isRecursive whether the permission change is applied recursively
 * @param folders remote paths to change
 * @returns true only when the remote command exits with code 0
 */
public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders);
    const commandResult = await this.execute(commandText);
    // A non-zero exit code is a failure; `>= 0` would report it as success.
    return commandResult.exitCode === 0;
}
/**
 * Remove a folder on the remote machine.
 *
 * @param folderName path of the folder to remove
 * @param isRecursive whether the folder content is removed as well
 * @param isForce whether the removal is forced
 * @returns true only when the remote command exits with code 0
 */
public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce);
    const commandResult = await this.execute(commandText);
    // A non-zero exit code is a failure; `>= 0` would report it as success.
    return commandResult.exitCode === 0;
}
/**
 * Remove a file, or the files matching a pattern inside a folder, on the
 * remote machine.
 *
 * @param folderOrFileName the file to remove, or the folder holding the files
 * @param filePattern optional pattern selecting files inside the folder
 * @returns true only when the remote command exits with code 0
 */
public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern);
    const commandResult = await this.execute(commandText);
    // A non-zero exit code is a failure; `>= 0` would report it as success.
    return commandResult.exitCode === 0;
}
/**
 * Read the trailing lines of a remote file.
 *
 * @param fileName path of the remote file
 * @param lineCount number of lines to read from the end of the file
 * @returns the command's stdout, or an empty string when nothing was produced
 */
public async readLastLines(fileName: string, lineCount: number = 1): Promise<string> {
    const commands = this.osCommands;
    const tailCommand = commands && commands.readLastLines(fileName, lineCount);
    const outcome = await this.execute(tailCommand);

    // Anything short of a non-empty stdout is reported as "no content".
    if (outcome === undefined || outcome.stdout === undefined || outcome.stdout.length === 0) {
        return "";
    }
    return outcome.stdout;
}
/**
 * Check whether the process recorded in a remote pid file is still running.
 *
 * @param pidFileName path of the remote file holding the process id
 * @returns the verdict of the OS command layer, or false when no OS command
 *          set is available
 */
public async isProcessAlive(pidFileName: string): Promise<boolean> {
    const commands = this.osCommands;
    const probeCommand = commands && commands.isProcessAliveCommand(pidFileName);
    const probeOutput = await this.execute(probeCommand);
    const alive = commands && commands.isProcessAliveProcessOutput(probeOutput);
    if (alive === undefined) {
        return false;
    }
    return alive;
}
/**
 * Kill the child processes of the process recorded in a remote pid file.
 *
 * @param pidFileName path of the remote file holding the parent process id
 * @returns true when the remote command exits with code 0
 */
public async killChildProcesses(pidFileName: string): Promise<boolean> {
    const commands = this.osCommands;
    const killCommand = commands && commands.killChildProcesses(pidFileName);
    const { exitCode } = await this.execute(killCommand);
    return exitCode === 0;
}
/**
 * Extract a remote archive into a remote folder.
 *
 * @param tarFileName path of the archive on the remote machine
 * @param targetFolder remote folder receiving the extracted content
 * @returns true when the remote command exits with code 0
 */
public async extractFile(tarFileName: string, targetFolder: string): Promise<boolean> {
    const commands = this.osCommands;
    const extractCommand = commands && commands.extractFile(tarFileName, targetFolder);
    const { exitCode } = await this.execute(extractCommand);
    return exitCode === 0;
}
/**
 * Run a script on the remote machine.
 *
 * @param script script file path when `isFile` is true, otherwise the script text
 * @param isFile whether `script` names a file or is inline script content
 * @param isInteractive when true the command is sent through an SSH shell
 *                      session instead of a plain exec channel
 * @returns true when the remote command exits with code 0
 */
public async executeScript(script: string, isFile: boolean, isInteractive: boolean = false): Promise<boolean> {
    const commands = this.osCommands;
    const scriptCommand = commands && commands.executeScript(script, isFile);
    const { exitCode } = await this.execute(scriptCommand, undefined, isInteractive);
    return exitCode === 0;
}
/** /**
* Copy local file to remote path * Copy local file to remote path
* @param localFilePath the path of local file * @param localFilePath the path of local file
* @param remoteFilePath the target path in remote machine * @param remoteFilePath the target path in remote machine
* @param sshClient SSH Client
*/ */
export function copyFileToRemote(localFilePath: string, remoteFilePath: string, sshClient: Client): Promise<boolean> { public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> {
const log: Logger = getLogger(); const log: Logger = getLogger();
log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`); log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
assert(sshClient !== undefined);
const deferred: Deferred<boolean> = new Deferred<boolean>(); const deferred: Deferred<boolean> = new Deferred<boolean>();
sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`); log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(err); deferred.reject(err);
...@@ -53,66 +179,13 @@ export namespace SSHClientUtility { ...@@ -53,66 +179,13 @@ export namespace SSHClientUtility {
return deferred.promise; return deferred.promise;
} }
/**
* Execute command on remote machine
* @param command the command to execute remotely
* @param client SSH Client
*/
export function remoteExeCommand(command: string, client: Client, useShell: boolean = false): Promise<RemoteCommandResult> {
const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`);
const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
let stdout: string = '';
let stderr: string = '';
let exitCode: number;
const callback = (err: Error, channel: ClientChannel): void => {
if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`);
deferred.reject(err);
return;
}
channel.on('data', (data: any) => {
stdout += data;
});
channel.on('exit', (code: any) => {
exitCode = <number>code;
log.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
deferred.resolve({
stdout: stdout,
stderr: stderr,
exitCode: exitCode
});
});
channel.stderr.on('data', function (data) {
stderr += data;
});
if (useShell) {
channel.stdin.write(`${command}\n`);
channel.end("exit\n");
}
return;
};
if (useShell) {
client.shell(callback);
} else {
client.exec(command, callback);
}
return deferred.promise;
}
/** /**
* Copy files and directories in local directory recursively to remote directory * Copy files and directories in local directory recursively to remote directory
* @param localDirectory local diretory * @param localDirectory local diretory
* @param remoteDirectory remote directory * @param remoteDirectory remote directory
* @param sshClient SSH client * @param sshClient SSH client
*/ */
export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string): Promise<void> { public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, remoteOS: string): Promise<void> {
const tmpSuffix: string = uniqueString(5); const tmpSuffix: string = uniqueString(5);
const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`); const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`);
const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`); const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`);
...@@ -120,16 +193,16 @@ export namespace SSHClientUtility { ...@@ -120,16 +193,16 @@ export namespace SSHClientUtility {
// Compress files in local directory to experiment root directory // Compress files in local directory to experiment root directory
await tarAdd(localTarPath, localDirectory); await tarAdd(localTarPath, localDirectory);
// Copy the compressed file to remoteDirectory and delete it // Copy the compressed file to remoteDirectory and delete it
await copyFileToRemote(localTarPath, remoteTarPath, sshClient); await this.copyFileToRemote(localTarPath, remoteTarPath);
await execRemove(localTarPath); await execRemove(localTarPath);
// Decompress the remote compressed file in and delete it // Decompress the remote compressed file in and delete it
await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient); await this.extractFile(remoteTarPath, remoteDirectory);
await remoteExeCommand(`rm ${remoteTarPath}`, sshClient); await this.removeFiles(remoteTarPath);
} }
export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> { public async getRemoteFileContent(filePath: string): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
getLogger() getLogger()
.error(`getRemoteFileContent: ${err.message}`); .error(`getRemoteFileContent: ${err.message}`);
...@@ -163,4 +236,59 @@ export namespace SSHClientUtility { ...@@ -163,4 +236,59 @@ export namespace SSHClientUtility {
return deferred.promise; return deferred.promise;
} }
/**
 * Run a command on the remote machine and collect its output.
 *
 * @param command command text; undefined is treated as an empty command
 * @param processOutput optional hook that may transform the collected result
 *                      before it is handed to the caller
 * @param useShell when true the command is typed into an SSH shell session
 *                 followed by `exit`; otherwise it runs on an exec channel
 * @returns stdout, stderr and exit code of the command
 * @throws the ssh2 error (as a rejected promise) when the channel cannot be opened
 */
private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> {
    const log: Logger = getLogger();
    log.debug(`remoteExeCommand: command: [${command}]`);
    // Normalise once so that both modes send the same text. Without this the
    // shell branch would type the literal word "undefined" into the remote shell.
    const commandText: string = command !== undefined ? command : "";
    const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
    let stdout: string = '';
    let stderr: string = '';
    let exitCode: number;

    const callback = (err: Error, channel: ClientChannel): void => {
        if (err !== undefined && err !== null) {
            log.error(`remoteExeCommand: ${err.message}`);
            deferred.reject(err);
            return;
        }

        channel.on('data', (data: any) => {
            stdout += data;
        });
        channel.on('exit', (code: any) => {
            exitCode = <number>code;
            log.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
            let result = {
                stdout: stdout,
                stderr: stderr,
                exitCode: exitCode
            };

            if (processOutput != undefined) {
                result = processOutput(result);
            }
            deferred.resolve(result);
        });
        channel.stderr.on('data', function (data) {
            stderr += data;
        });

        if (useShell) {
            // In shell mode the command is fed through stdin, then the shell
            // is asked to exit so that the 'exit' event fires.
            channel.stdin.write(`${commandText}\n`);
            channel.end("exit\n");
        }

        return;
    };

    if (useShell) {
        this.sshClient.shell(callback);
    } else {
        this.sshClient.exec(commandText, callback);
    }

    return deferred.promise;
}
} }
export { ShellExecutor };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { LinuxCommands } from '../extends/linuxCommands';
// import { TrialConfigMetadataKey } from '../trialConfigMetadataKey';
// Checks the exact command text LinuxCommands generates for each operation.
describe('Unit Test for linuxCommands', () => {
    // Short alias so each expectation reads as `expect(actual).to.equal(expected)`.
    const expect = chai.expect;
    let linuxCommands: LinuxCommands;

    before(() => {
        chai.should();
        chai.use(chaiAsPromised);
        prepareUnitTest();
    });

    after(() => {
        cleanupUnitTest();
    });

    beforeEach(() => {
        // A fresh instance from the component container for every test.
        linuxCommands = component.get(LinuxCommands);
    });

    afterEach(() => {
    });

    it('joinPath', async () => {
        expect(linuxCommands.joinPath("/root/", "/first")).to.equal("/root/first");
        expect(linuxCommands.joinPath("/root", "first")).to.equal("/root/first");
        expect(linuxCommands.joinPath("/root/", "first")).to.equal("/root/first");
        expect(linuxCommands.joinPath("root/", "first")).to.equal("root/first");
        expect(linuxCommands.joinPath("root/")).to.equal("root/");
        expect(linuxCommands.joinPath("root")).to.equal("root");
        expect(linuxCommands.joinPath("./root")).to.equal("./root");
        expect(linuxCommands.joinPath("")).to.equal(".");
        expect(linuxCommands.joinPath("..")).to.equal("..");
    });

    it('createFolder', async () => {
        expect(linuxCommands.createFolder("test")).to.equal("mkdir -p 'test'");
        expect(linuxCommands.createFolder("test", true)).to.equal("umask 0; mkdir -p 'test'");
    });

    it('allowPermission', async () => {
        expect(linuxCommands.allowPermission(true, "test", "test1")).to.equal("chmod 777 -R 'test' 'test1'");
        expect(linuxCommands.allowPermission(false, "test")).to.equal("chmod 777 'test'");
    });

    it('removeFolder', async () => {
        expect(linuxCommands.removeFolder("test")).to.equal("rm -df 'test'");
        expect(linuxCommands.removeFolder("test", true)).to.equal("rm -rf 'test'");
        expect(linuxCommands.removeFolder("test", true, false)).to.equal("rm -r 'test'");
        expect(linuxCommands.removeFolder("test", false, false)).to.equal("rm 'test'");
    });

    it('removeFiles', async () => {
        expect(linuxCommands.removeFiles("test", "*.sh")).to.equal("rm 'test/*.sh'");
        expect(linuxCommands.removeFiles("test", "")).to.equal("rm 'test'");
    });

    it('readLastLines', async () => {
        expect(linuxCommands.readLastLines("test", 3)).to.equal("tail -n 3 'test'");
    });

    it('isProcessAlive', async () => {
        expect(linuxCommands.isProcessAliveCommand("test")).to.equal("kill -0 `cat 'test'`");

        // Exit code 0 is reported as "alive", any other code as "not alive".
        const aliveOutput = { exitCode: 0, stdout: "", stderr: "" };
        const deadOutput = { exitCode: 10, stdout: "", stderr: "" };
        expect(linuxCommands.isProcessAliveProcessOutput(aliveOutput)).to.equal(true);
        expect(linuxCommands.isProcessAliveProcessOutput(deadOutput)).to.equal(false);
    });

    it('killChildProcesses', async () => {
        expect(linuxCommands.killChildProcesses("test")).to.equal("pkill -P `cat 'test'`");
    });

    it('extractFile', async () => {
        expect(linuxCommands.extractFile("test.tar", "testfolder")).to.equal("tar -oxzf 'test.tar' -C 'testfolder'");
    });

    it('executeScript', async () => {
        expect(linuxCommands.executeScript("test.sh", true)).to.equal("bash 'test.sh'");
        expect(linuxCommands.executeScript("test script'\"", false)).to.equal(`bash -c \"test script'\\""`);
    });
});
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import { Client } from 'ssh2';
import { ShellExecutor } from '../shellExecutor';
import { prepareUnitTest, cleanupUnitTest } from '../../../common/utils';
// Local file created in before() and uploaded by the copy tests.
const LOCALFILE: string = '/tmp/localSshclientUTData';
// Destination path on the remote machine for the uploaded file.
const REMOTEFILE: string = '/tmp/remoteSshclientUTData';
// Remote folder created and then removed by the 'Test mkdir' case.
const REMOTEFOLDER: string = '/tmp/remoteSshclientUTFolder';
/**
 * Upload LOCALFILE to REMOTEFILE once through the given executor.
 * @param executor an executor to copy the file with
 */
async function copyFile(executor: ShellExecutor): Promise<void> {
await executor.copyFileToRemote(LOCALFILE, REMOTEFILE);
}
/**
 * Upload LOCALFILE to REMOTEFILE ten times in a row, one copy at a time.
 * @param executor the executor used for every upload
 */
async function copyFileToRemoteLoop(executor: ShellExecutor): Promise<void> {
    let remaining = 10;
    while (remaining > 0) {
        await executor.copyFileToRemote(LOCALFILE, REMOTEFILE);
        remaining -= 1;
    }
}
/**
 * Read REMOTEFILE back ten times in a row, one read at a time.
 * @param executor the executor used for every read
 */
async function getRemoteFileContentLoop(executor: ShellExecutor): Promise<void> {
    let remaining = 10;
    while (remaining > 0) {
        await executor.getRemoteFileContent(REMOTEFILE);
        remaining -= 1;
    }
}
// Integration tests for ShellExecutor. They need a reachable remote machine
// described in ../../.vscode/rminfo.json; without that file every case
// returns early.
describe('ShellExecutor test', () => {
    let skip: boolean = false;
    let rmMeta: any;
    try {
        rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8'));
        // NOTE(review): this prints the whole machine config; if it contains a
        // password or passphrase they end up in the test log.
        console.log(rmMeta);
    } catch (err) {
        console.log(`Please configure rminfo.json to enable remote machine test.${err}`);
        skip = true;
    }

    before(async () => {
        chai.should();
        chai.use(chaiAsPromised);
        await cpp.exec(`echo '1234' > ${LOCALFILE}`);
        prepareUnitTest();
    });

    after(() => {
        cleanupUnitTest();
        fs.unlinkSync(LOCALFILE);
    });

    it('Test mkdir', async () => {
        if (skip) {
            return;
        }
        const shellExecutor: ShellExecutor = new ShellExecutor();
        await shellExecutor.initialize(rmMeta);
        try {
            let result = await shellExecutor.createFolder(REMOTEFOLDER, false);
            chai.expect(result).eq(true);
            result = await shellExecutor.removeFolder(REMOTEFOLDER);
            chai.expect(result).eq(true);
        } finally {
            // Always release the SSH connection, even when an expectation
            // fails; an open connection keeps the test process alive.
            shellExecutor.close();
        }
    });

    it('Test ShellExecutor', async () => {
        if (skip) {
            return;
        }
        const shellExecutor: ShellExecutor = new ShellExecutor();
        await shellExecutor.initialize(rmMeta);
        try {
            await copyFile(shellExecutor);
            // Run uploads and downloads concurrently over the same connection.
            await Promise.all([
                copyFileToRemoteLoop(shellExecutor),
                copyFileToRemoteLoop(shellExecutor),
                copyFileToRemoteLoop(shellExecutor),
                getRemoteFileContentLoop(shellExecutor)
            ]);
        } finally {
            // Always release the SSH connection, even when a transfer fails.
            shellExecutor.close();
        }
    });
});
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import { Client } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { SSHClientUtility } from '../remote_machine/sshClientUtility';
// Fixture paths shared by the sshClientUtility test helpers and cases in this file.
// Local file that the `before` hook creates and the `after` hook deletes.
// NOTE(review): local and remote paths are the same string; if the configured
// remote is this machine the upload would target the source file itself — confirm.
const LOCALFILE: string = '/tmp/sshclientUTData';
// Path on the remote side used for uploads and for reading the content back.
const REMOTEFILE: string = '/tmp/sshclientUTData';
/**
 * Uploads the local fixture file to the remote fixture path through an SFTP
 * session opened on the given connection.
 *
 * The returned promise rejects when the SFTP session cannot be opened or when
 * the transfer reports an error; otherwise it resolves once the transfer
 * callback has fired. The SFTP session is ended before the outcome is settled.
 *
 * @param conn connected ssh2 client used to open the SFTP session
 */
async function copyFile(conn: Client): Promise<void> {
    const pending: Deferred<void> = new Deferred<void>();

    conn.sftp((sftpErr, channel) => {
        if (!sftpErr) {
            channel.fastPut(LOCALFILE, REMOTEFILE, (fastPutErr: Error) => {
                channel.end();
                if (!fastPutErr) {
                    pending.resolve();
                } else {
                    pending.reject(fastPutErr);
                }
            });

            return;
        }
        pending.reject(sftpErr);
    });

    return pending.promise;
}
/**
 * Uploads the fixture file 500 times in sequence over the given connection,
 * printing the iteration index before each upload.
 *
 * @param conn connected ssh2 client passed to SSHClientUtility
 */
async function copyFileToRemoteLoop(conn: Client): Promise<void> {
    let round: number = 0;
    while (round < 500) {
        console.log(round);
        await SSHClientUtility.copyFileToRemote(LOCALFILE, REMOTEFILE, conn);
        round += 1;
    }
}
/**
 * Runs `ls` on the remote side 500 times in sequence over the given
 * connection, printing the iteration index before each command.
 *
 * @param conn connected ssh2 client passed to SSHClientUtility
 */
async function remoteExeCommandLoop(conn: Client): Promise<void> {
    let round: number = 0;
    while (round < 500) {
        console.log(round);
        await SSHClientUtility.remoteExeCommand('ls', conn);
        round += 1;
    }
}
/**
 * Reads the remote fixture file 500 times in sequence over the given
 * connection, printing the iteration index before each read. The content
 * itself is discarded.
 *
 * @param conn connected ssh2 client passed to SSHClientUtility
 */
async function getRemoteFileContentLoop(conn: Client): Promise<void> {
    let round: number = 0;
    while (round < 500) {
        console.log(round);
        await SSHClientUtility.getRemoteFileContent(REMOTEFILE, conn);
        round += 1;
    }
}
/**
 * Stress test for SSHClientUtility against a real remote machine.
 *
 * Connection settings are read from `../../.vscode/rminfo.json` (relative to
 * the working directory of the test run).
 */
describe('sshClientUtility test', () => {
    // NOTE(review): `skip` starts as true and is never set to false, so the case
    // below always returns early even when rminfo.json is present. This looks
    // like a deliberate way to disable the test — confirm before changing it.
    let skip: boolean = true;
    let rmMeta: any;
    try {
        rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8'));
    } catch (err) {
        skip = true;
    }

    before(async () => {
        // Create the local fixture file that the copy helpers upload.
        await cpp.exec(`echo '1234' > ${LOCALFILE}`);
    });

    after(() => {
        fs.unlinkSync(LOCALFILE);
    });

    // Uploads the fixture once, then runs three upload loops, one command loop
    // and one read loop concurrently over a single SSH connection.
    it('Test SSHClientUtility', (done) => {
        if (skip) {
            done();

            return;
        }
        const conn: Client = new Client();

        // Report the outcome to mocha at most once and always close the
        // connection, so neither a failure nor a late 'error' event leaves the
        // test hanging or the socket open.
        let finished: boolean = false;
        const finish = (err?: unknown): void => {
            if (finished) {
                return;
            }
            finished = true;
            conn.end();
            if (err === undefined) {
                done();
            } else {
                done(err);
            }
        };

        // Connection-level failures (e.g. refused connection, bad credentials)
        // would otherwise be thrown as an unhandled 'error' event.
        conn.on('error', (err: Error) => {
            finish(err);
        });
        conn.on('ready', async () => {
            try {
                await copyFile(conn);
                await Promise.all([
                    copyFileToRemoteLoop(conn),
                    copyFileToRemoteLoop(conn),
                    copyFileToRemoteLoop(conn),
                    remoteExeCommandLoop(conn),
                    getRemoteFileContentLoop(conn)
                ]);
                finish();
            } catch (err) {
                // Without this, a rejected operation would never reach `done`
                // and the test would only fail by timing out.
                finish(err);
            }
        }).connect(rmMeta);
    });
});
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment