Unverified Commit cf6a1de6 authored by Chi Song's avatar Chi Song Committed by GitHub
Browse files

Refactor: remote machine isolate OS commands phase 1 (#2376)

To support Windows node in remote mode, this PR adds a layer of commands (osCommands) to deal difference between Windows and Unix-like OS. To share code, ShellExecutor is added to enrich original SshClient class.

I will implement windows version commands in next phase.

This pattern can be expanded to Local or other platform in future, so I moved related code to common folder for sharing.
parent edd3f8ac
...@@ -16,7 +16,7 @@ git clone https://github.com/Microsoft/nni.git ...@@ -16,7 +16,7 @@ git clone https://github.com/Microsoft/nni.git
to clone the source code to clone the source code
### 2. Prepare the debug environment and install dependencies** ### 2. Prepare the debug environment and install dependencies
Change directory to the source code folder, then run the command Change directory to the source code folder, then run the command
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData";
class LinuxCommands extends OsCommands {
public createFolder(folderName: string, sharedFolder: boolean = false): string {
let command;
if (sharedFolder) {
command = `umask 0; mkdir -p '${folderName}'`;
} else {
command = `mkdir -p '${folderName}'`;
}
return command;
}
public allowPermission(isRecursive: boolean = false, ...folders: string[]): string {
const folderString = folders.join("' '");
let command;
if (isRecursive) {
command = `chmod 777 -R '${folderString}'`;
} else {
command = `chmod 777 '${folderString}'`;
}
return command;
}
public removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): string {
let flags = '';
if (isForce || isRecursive) {
flags = `-${isRecursive ? 'r' : 'd'}${isForce ? 'f' : ''} `;
}
const command = `rm ${flags}'${folderName}'`;
return command;
}
public removeFiles(folderName: string, filePattern: string): string {
const files = this.joinPath(folderName, filePattern);
const command = `rm '${files}'`;
return command;
}
public readLastLines(fileName: string, lineCount: number = 1): string {
const command = `tail -n ${lineCount} '${fileName}'`;
return command;
}
public isProcessAliveCommand(pidFileName: string): string {
const command = `kill -0 \`cat '${pidFileName}'\``;
return command;
}
public isProcessAliveProcessOutput(commandResult: RemoteCommandResult): boolean {
let result = true;
if (commandResult.exitCode !== 0) {
result = false;
}
return result;
}
public killChildProcesses(pidFileName: string): string {
const command = `pkill -P \`cat '${pidFileName}'\``;
return command;
}
public extractFile(tarFileName: string, targetFolder: string): string {
const command = `tar -oxzf '${tarFileName}' -C '${targetFolder}'`;
return command;
}
public executeScript(script: string, isFile: boolean): string {
let command: string;
if (isFile) {
command = `bash '${script}'`;
} else {
script = script.replace('"', '\\"');
command = `bash -c "${script}"`;
}
return command;
}
}
export { LinuxCommands };
...@@ -8,7 +8,7 @@ import { getLogger, Logger } from '../../common/log'; ...@@ -8,7 +8,7 @@ import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils'; import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData'; import { GPUInfo } from '../common/gpuData';
import { import {
parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, SSHClientManager parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, ExecutorManager
} from './remoteMachineData'; } from './remoteMachineData';
type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
...@@ -18,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; ...@@ -18,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
*/ */
export class GPUScheduler { export class GPUScheduler {
private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; private readonly machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>;
private readonly log: Logger = getLogger(); private readonly log: Logger = getLogger();
private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin'; private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
private roundRobinIndex: number = 0; private roundRobinIndex: number = 0;
...@@ -26,12 +26,12 @@ export class GPUScheduler { ...@@ -26,12 +26,12 @@ export class GPUScheduler {
/** /**
* Constructor * Constructor
* @param machineSSHClientMap map from remote machine to sshClient * @param machineExecutorMap map from remote machine to executor
*/ */
constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) { constructor(machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>) {
assert(machineSSHClientMap.size > 0); assert(machineExecutorMap.size > 0);
this.machineSSHClientMap = machineSSHClientMap; this.machineExecutorMap = machineExecutorMap;
this.configuredRMs = Array.from(machineSSHClientMap.keys()); this.configuredRMs = Array.from(machineExecutorMap.keys());
} }
/** /**
...@@ -43,7 +43,7 @@ export class GPUScheduler { ...@@ -43,7 +43,7 @@ export class GPUScheduler {
requiredGPUNum = 0; requiredGPUNum = 0;
} }
assert(requiredGPUNum >= 0); assert(requiredGPUNum >= 0);
const allRMs: RemoteMachineMeta[] = Array.from(this.machineSSHClientMap.keys()); const allRMs: RemoteMachineMeta[] = Array.from(this.machineExecutorMap.keys());
assert(allRMs.length > 0); assert(allRMs.length > 0);
// Step 1: Check if required GPU number not exceeds the total GPU number in all machines // Step 1: Check if required GPU number not exceeds the total GPU number in all machines
...@@ -135,7 +135,7 @@ export class GPUScheduler { ...@@ -135,7 +135,7 @@ export class GPUScheduler {
*/ */
private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> { private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>(); const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => { this.machineExecutorMap.forEach((executorManager: ExecutorManager, rmMeta: RemoteMachineMeta) => {
// Assgin totoal GPU count as init available GPU number // Assgin totoal GPU count as init available GPU number
if (rmMeta.gpuSummary !== undefined) { if (rmMeta.gpuSummary !== undefined) {
const availableGPUs: GPUInfo[] = []; const availableGPUs: GPUInfo[] = [];
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { RemoteCommandResult } from "./remoteMachineData";
abstract class OsCommands {
protected pathSpliter: string = '/';
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`);
public abstract createFolder(folderName: string, sharedFolder: boolean): string;
public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string;
public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string;
public abstract removeFiles(folderOrFileName: string, filePattern: string): string;
public abstract readLastLines(fileName: string, lineCount: number): string;
public abstract isProcessAliveCommand(pidFileName: string): string;
public abstract isProcessAliveProcessOutput(result: RemoteCommandResult): boolean;
public abstract killChildProcesses(pidFileName: string): string;
public abstract extractFile(tarFileName: string, targetFolder: string): string;
public abstract executeScript(script: string, isFile: boolean): string;
public joinPath(...paths: string[]): string {
let dir: string = paths.filter((path: any) => path !== '').join(this.pathSpliter);
if (dir === '') {
dir = '.';
} else {
dir = dir.replace(this.multiplePathSpliter, this.pathSpliter);
}
return dir;
}
}
export { OsCommands };
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
'use strict'; 'use strict';
import * as fs from 'fs'; import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUInfo, GPUSummary } from '../common/gpuData'; import { GPUInfo, GPUSummary } from '../common/gpuData';
import { ShellExecutor } from './shellExecutor';
/** /**
* Metadata of remote machine for configuration and statuc query * Metadata of remote machine for configuration and statuc query
...@@ -72,7 +70,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -72,7 +70,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public gpuIndices: GPUInfo[]; public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: TrialJobApplicationForm) { workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.submitTime = submitTime; this.submitTime = submitTime;
...@@ -84,149 +82,88 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -84,149 +82,88 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
} }
/** /**
* The remote machine ssh client used for trial and gpu detector * The remote machine executor manager
*/ */
export class SSHClient { export class ExecutorManager {
private readonly sshClient: Client; private readonly executorArray: ShellExecutor[];
private usedConnectionNumber: number; //count the connection number of every client
constructor(sshClient: Client, usedConnectionNumber: number) {
this.sshClient = sshClient;
this.usedConnectionNumber = usedConnectionNumber;
}
public get getSSHClientInstance(): Client {
return this.sshClient;
}
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
public addUsedConnectionNumber(): void {
this.usedConnectionNumber += 1;
}
public minusUsedConnectionNumber(): void {
this.usedConnectionNumber -= 1;
}
}
/**
* The remote machine ssh client manager
*/
export class SSHClientManager {
private readonly sshClientArray: SSHClient[];
private readonly maxTrialNumberPerConnection: number; private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta; private readonly rmMeta: RemoteMachineMeta;
constructor(sshClientArray: SSHClient[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) { constructor(executorArray: ShellExecutor[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
this.rmMeta = rmMeta; this.rmMeta = rmMeta;
this.sshClientArray = sshClientArray; this.executorArray = executorArray;
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection; this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
} }
/** /**
* find a available ssh client in ssh array, if no ssh client available, return undefined * find a available executor, if no executor available, return a new one
*/ */
public async getAvailableSSHClient(): Promise<Client> { public async getAvailableExecutor(): Promise<ShellExecutor> {
const deferred: Deferred<Client> = new Deferred<Client>(); for (const index of this.executorArray.keys()) {
for (const index of this.sshClientArray.keys()) { const connectionNumber: number = this.executorArray[index].getUsedConnectionNumber;
const connectionNumber: number = this.sshClientArray[index].getUsedConnectionNumber;
if (connectionNumber < this.maxTrialNumberPerConnection) { if (connectionNumber < this.maxTrialNumberPerConnection) {
this.sshClientArray[index].addUsedConnectionNumber(); this.executorArray[index].addUsedConnectionNumber();
deferred.resolve(this.sshClientArray[index].getSSHClientInstance);
return deferred.promise; return this.executorArray[index];
} }
} }
//init a new ssh client if could not get an available one //init a new executor if could not get an available one
return this.initNewSSHClient(); return await this.initNewShellExecutor();
} }
/** /**
* add a new ssh client to sshClientArray * add a new executor to executorArray
* @param sshClient SSH Client * @param executor ShellExecutor
*/ */
public addNewSSHClient(client: Client): void { public addNewShellExecutor(executor: ShellExecutor): void {
this.sshClientArray.push(new SSHClient(client, 1)); this.executorArray.push(executor);
} }
/** /**
* first ssh client instance is used for gpu collector and host job * first executor instance is used for gpu collector and host job
*/ */
public getFirstSSHClient(): Client { public getFirstExecutor(): ShellExecutor {
return this.sshClientArray[0].getSSHClientInstance; return this.executorArray[0];
} }
/** /**
* close all of ssh client * close all of executor
*/ */
public closeAllSSHClient(): void { public closeAllExecutor(): void {
for (const sshClient of this.sshClientArray) { for (const executor of this.executorArray) {
sshClient.getSSHClientInstance.end(); executor.close();
} }
} }
/** /**
* retrieve resource, minus a number for given ssh client * retrieve resource, minus a number for given executor
* @param client SSH Client * @param executor executor
*/ */
public releaseConnection(client: Client | undefined): void { public releaseConnection(executor: ShellExecutor | undefined): void {
if (client === undefined) { if (executor === undefined) {
throw new Error(`could not release a undefined ssh client`); throw new Error(`could not release a undefined executor`);
} }
for (const index of this.sshClientArray.keys()) { for (const index of this.executorArray.keys()) {
if (this.sshClientArray[index].getSSHClientInstance === client) { if (this.executorArray[index] === executor) {
this.sshClientArray[index].minusUsedConnectionNumber(); this.executorArray[index].minusUsedConnectionNumber();
break; break;
} }
} }
} }
/** /**
* Create a new ssh connection client and initialize it * Create a new connection executor and initialize it
*/ */
private initNewSSHClient(): Promise<Client> { private async initNewShellExecutor(): Promise<ShellExecutor> {
const deferred: Deferred<Client> = new Deferred<Client>(); const executor = new ShellExecutor();
const conn: Client = new Client(); await executor.initialize(this.rmMeta);
const connectConfig: ConnectConfig = { return executor;
host: this.rmMeta.ip,
port: this.rmMeta.port,
username: this.rmMeta.username,
tryKeyboard: true };
if (this.rmMeta.passwd !== undefined) {
connectConfig.password = this.rmMeta.passwd;
} else if (this.rmMeta.sshKeyPath !== undefined) {
if (!fs.existsSync(this.rmMeta.sshKeyPath)) {
//SSh key path is not a valid file, reject
deferred.reject(new Error(`${this.rmMeta.sshKeyPath} does not exist.`));
}
const privateKey: string = fs.readFileSync(this.rmMeta.sshKeyPath, 'utf8');
connectConfig.privateKey = privateKey;
connectConfig.passphrase = this.rmMeta.passphrase;
} else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
}
conn.on('ready', () => {
this.addNewSSHClient(conn);
deferred.resolve(conn);
})
.on('error', (err: Error) => {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => {
finish([this.rmMeta.passwd]);
})
.connect(connectConfig);
return deferred.promise;
} }
} }
export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType}; export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType };
export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string}; export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string };
export enum ScheduleResultType { export enum ScheduleResultType {
// Schedule succeeded // Schedule succeeded
...@@ -240,7 +177,7 @@ export enum ScheduleResultType { ...@@ -240,7 +177,7 @@ export enum ScheduleResultType {
} }
export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string = export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string =
`#!/bin/bash `#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \
NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5} NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5}
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
...@@ -251,7 +188,7 @@ python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8 ...@@ -251,7 +188,7 @@ python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8
echo $? \`date +%s%3N\` >{12}`; echo $? \`date +%s%3N\` >{12}`;
export const HOST_JOB_SHELL_FORMAT: string = export const HOST_JOB_SHELL_FORMAT: string =
`#!/bin/bash `#!/bin/bash
cd {0} cd {0}
echo $$ >{1} echo $$ >{1}
eval {2} >stdout 2>stderr eval {2} >stdout 2>stderr
......
...@@ -6,33 +6,159 @@ ...@@ -6,33 +6,159 @@
import * as assert from 'assert'; import * as assert from 'assert';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import { Client, ClientChannel, SFTPWrapper } from 'ssh2'; import * as fs from 'fs';
import { Client, ClientChannel, SFTPWrapper, ConnectConfig } from 'ssh2';
import { Deferred } from "ts-deferred";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import * as stream from 'stream'; import * as stream from 'stream';
import { Deferred } from 'ts-deferred'; import { OsCommands } from "./osCommands";
import { NNIError, NNIErrorNames } from '../../common/errors'; import { LinuxCommands } from "./extends/linuxCommands";
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { execRemove, tarAdd } from '../common/util'; import { execRemove, tarAdd } from '../common/util';
import { RemoteCommandResult } from './remoteMachineData'; import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils';
/** class ShellExecutor {
* private sshClient: Client = new Client();
* Utility for frequent operations towards SSH client private osCommands: OsCommands | undefined;
* private usedConnectionNumber: number = 0; //count the connection number of every client
*/
export namespace SSHClientUtility { protected pathSpliter: string = '/';
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`);
public async initialize(rmMeta: RemoteMachineMeta): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const connectConfig: ConnectConfig = {
host: rmMeta.ip,
port: rmMeta.port,
username: rmMeta.username,
tryKeyboard: true
};
if (rmMeta.passwd !== undefined) {
connectConfig.password = rmMeta.passwd;
} else if (rmMeta.sshKeyPath !== undefined) {
if (!fs.existsSync(rmMeta.sshKeyPath)) {
//SSh key path is not a valid file, reject
deferred.reject(new Error(`${rmMeta.sshKeyPath} does not exist.`));
}
const privateKey: string = fs.readFileSync(rmMeta.sshKeyPath, 'utf8');
connectConfig.privateKey = privateKey;
connectConfig.passphrase = rmMeta.passphrase;
} else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
}
this.sshClient.on('ready', async () => {
// check OS type: windows or else
const result = await this.execute("ver");
if (result.exitCode == 0 && result.stdout.search("Windows") > -1) {
// not implement Windows commands yet.
throw new Error("not implement Windows commands yet.");
} else {
this.osCommands = new LinuxCommands();
}
deferred.resolve();
}).on('error', (err: Error) => {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => {
finish([rmMeta.passwd]);
}).connect(connectConfig);
return deferred.promise;
}
public close(): void {
this.sshClient.end();
}
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
public addUsedConnectionNumber(): void {
this.usedConnectionNumber += 1;
}
public minusUsedConnectionNumber(): void {
this.usedConnectionNumber -= 1;
}
public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder);
const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0;
return result;
}
public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders);
const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0;
return result;
}
public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce);
const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0;
return result;
}
public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern);
const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0;
return result;
}
public async readLastLines(fileName: string, lineCount: number = 1): Promise<string> {
const commandText = this.osCommands && this.osCommands.readLastLines(fileName, lineCount);
const commandResult = await this.execute(commandText);
let result: string = "";
if (commandResult !== undefined && commandResult.stdout !== undefined && commandResult.stdout.length > 0) {
result = commandResult.stdout;
}
return result;
}
public async isProcessAlive(pidFileName: string): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.isProcessAliveCommand(pidFileName);
const commandResult = await this.execute(commandText);
const result = this.osCommands && this.osCommands.isProcessAliveProcessOutput(commandResult);
return result !== undefined ? result : false;
}
public async killChildProcesses(pidFileName: string): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.killChildProcesses(pidFileName);
const commandResult = await this.execute(commandText);
return commandResult.exitCode == 0;
}
public async extractFile(tarFileName: string, targetFolder: string): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.extractFile(tarFileName, targetFolder);
const commandResult = await this.execute(commandText);
return commandResult.exitCode == 0;
}
public async executeScript(script: string, isFile: boolean, isInteractive: boolean = false): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.executeScript(script, isFile);
const commandResult = await this.execute(commandText, undefined, isInteractive);
return commandResult.exitCode == 0;
}
/** /**
* Copy local file to remote path * Copy local file to remote path
* @param localFilePath the path of local file * @param localFilePath the path of local file
* @param remoteFilePath the target path in remote machine * @param remoteFilePath the target path in remote machine
* @param sshClient SSH Client
*/ */
export function copyFileToRemote(localFilePath: string, remoteFilePath: string, sshClient: Client): Promise<boolean> { public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> {
const log: Logger = getLogger(); const log: Logger = getLogger();
log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`); log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
assert(sshClient !== undefined);
const deferred: Deferred<boolean> = new Deferred<boolean>(); const deferred: Deferred<boolean> = new Deferred<boolean>();
sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`); log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(err); deferred.reject(err);
...@@ -53,66 +179,13 @@ export namespace SSHClientUtility { ...@@ -53,66 +179,13 @@ export namespace SSHClientUtility {
return deferred.promise; return deferred.promise;
} }
/**
* Execute command on remote machine
* @param command the command to execute remotely
* @param client SSH Client
*/
export function remoteExeCommand(command: string, client: Client, useShell: boolean = false): Promise<RemoteCommandResult> {
const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`);
const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
let stdout: string = '';
let stderr: string = '';
let exitCode: number;
const callback = (err: Error, channel: ClientChannel): void => {
if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`);
deferred.reject(err);
return;
}
channel.on('data', (data: any) => {
stdout += data;
});
channel.on('exit', (code: any) => {
exitCode = <number>code;
log.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
deferred.resolve({
stdout: stdout,
stderr: stderr,
exitCode: exitCode
});
});
channel.stderr.on('data', function (data) {
stderr += data;
});
if (useShell) {
channel.stdin.write(`${command}\n`);
channel.end("exit\n");
}
return;
};
if (useShell) {
client.shell(callback);
} else {
client.exec(command, callback);
}
return deferred.promise;
}
/** /**
* Copy files and directories in local directory recursively to remote directory * Copy files and directories in local directory recursively to remote directory
* @param localDirectory local diretory * @param localDirectory local diretory
* @param remoteDirectory remote directory * @param remoteDirectory remote directory
* @param sshClient SSH client * @param sshClient SSH client
*/ */
export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string): Promise<void> { public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, remoteOS: string): Promise<void> {
const tmpSuffix: string = uniqueString(5); const tmpSuffix: string = uniqueString(5);
const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`); const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`);
const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`); const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`);
...@@ -120,16 +193,16 @@ export namespace SSHClientUtility { ...@@ -120,16 +193,16 @@ export namespace SSHClientUtility {
// Compress files in local directory to experiment root directory // Compress files in local directory to experiment root directory
await tarAdd(localTarPath, localDirectory); await tarAdd(localTarPath, localDirectory);
// Copy the compressed file to remoteDirectory and delete it // Copy the compressed file to remoteDirectory and delete it
await copyFileToRemote(localTarPath, remoteTarPath, sshClient); await this.copyFileToRemote(localTarPath, remoteTarPath);
await execRemove(localTarPath); await execRemove(localTarPath);
// Decompress the remote compressed file in and delete it // Decompress the remote compressed file in and delete it
await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient); await this.extractFile(remoteTarPath, remoteDirectory);
await remoteExeCommand(`rm ${remoteTarPath}`, sshClient); await this.removeFiles(remoteTarPath);
} }
export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> { public async getRemoteFileContent(filePath: string): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
getLogger() getLogger()
.error(`getRemoteFileContent: ${err.message}`); .error(`getRemoteFileContent: ${err.message}`);
...@@ -163,4 +236,59 @@ export namespace SSHClientUtility { ...@@ -163,4 +236,59 @@ export namespace SSHClientUtility {
return deferred.promise; return deferred.promise;
} }
private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> {
const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`);
const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
let stdout: string = '';
let stderr: string = '';
let exitCode: number;
const callback = (err: Error, channel: ClientChannel): void => {
if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`);
deferred.reject(err);
return;
}
channel.on('data', (data: any) => {
stdout += data;
});
channel.on('exit', (code: any) => {
exitCode = <number>code;
log.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
let result = {
stdout: stdout,
stderr: stderr,
exitCode: exitCode
};
if (processOutput != undefined) {
result = processOutput(result);
}
deferred.resolve(result);
});
channel.stderr.on('data', function (data) {
stderr += data;
});
if (useShell) {
channel.stdin.write(`${command}\n`);
channel.end("exit\n");
}
return;
};
if (useShell) {
this.sshClient.shell(callback);
} else {
this.sshClient.exec(command !== undefined ? command : "", callback);
}
return deferred.promise;
}
} }
export { ShellExecutor };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { LinuxCommands } from '../extends/linuxCommands';
// import { TrialConfigMetadataKey } from '../trialConfigMetadataKey';
describe('Unit Test for linuxCommands', () => {
let linuxCommands: LinuxCommands
before(() => {
chai.should();
chai.use(chaiAsPromised);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
});
beforeEach(() => {
linuxCommands = component.get(LinuxCommands);
});
afterEach(() => {
});
it('joinPath', async () => {
chai.expect(linuxCommands.joinPath("/root/", "/first")).to.equal("/root/first");
chai.expect(linuxCommands.joinPath("/root", "first")).to.equal("/root/first");
chai.expect(linuxCommands.joinPath("/root/", "first")).to.equal("/root/first");
chai.expect(linuxCommands.joinPath("root/", "first")).to.equal("root/first");
chai.expect(linuxCommands.joinPath("root/")).to.equal("root/");
chai.expect(linuxCommands.joinPath("root")).to.equal("root");
chai.expect(linuxCommands.joinPath("./root")).to.equal("./root");
chai.expect(linuxCommands.joinPath("")).to.equal(".");
chai.expect(linuxCommands.joinPath("..")).to.equal("..");
})
it('createFolder', async () => {
chai.expect(linuxCommands.createFolder("test")).to.equal("mkdir -p 'test'");
chai.expect(linuxCommands.createFolder("test", true)).to.equal("umask 0; mkdir -p 'test'");
})
it('allowPermission', async () => {
chai.expect(linuxCommands.allowPermission(true, "test", "test1")).to.equal("chmod 777 -R 'test' 'test1'");
chai.expect(linuxCommands.allowPermission(false, "test")).to.equal("chmod 777 'test'");
})
it('removeFolder', async () => {
chai.expect(linuxCommands.removeFolder("test")).to.equal("rm -df 'test'");
chai.expect(linuxCommands.removeFolder("test", true)).to.equal("rm -rf 'test'");
chai.expect(linuxCommands.removeFolder("test", true, false)).to.equal("rm -r 'test'");
chai.expect(linuxCommands.removeFolder("test", false, false)).to.equal("rm 'test'");
})
it('removeFiles', async () => {
chai.expect(linuxCommands.removeFiles("test", "*.sh")).to.equal("rm 'test/*.sh'");
chai.expect(linuxCommands.removeFiles("test", "")).to.equal("rm 'test'");
})
it('readLastLines', async () => {
chai.expect(linuxCommands.readLastLines("test", 3)).to.equal("tail -n 3 'test'");
})
it('isProcessAlive', async () => {
chai.expect(linuxCommands.isProcessAliveCommand("test")).to.equal("kill -0 `cat 'test'`");
chai.expect(linuxCommands.isProcessAliveProcessOutput(
{
exitCode: 0,
stdout: "",
stderr: ""
}
)).to.equal(true);
chai.expect(linuxCommands.isProcessAliveProcessOutput(
{
exitCode: 10,
stdout: "",
stderr: ""
}
)).to.equal(false);
})
it('killChildProcesses', async () => {
chai.expect(linuxCommands.killChildProcesses("test")).to.equal("pkill -P `cat 'test'`");
})
it('extractFile', async () => {
chai.expect(linuxCommands.extractFile("test.tar", "testfolder")).to.equal("tar -oxzf 'test.tar' -C 'testfolder'");
})
it('executeScript', async () => {
chai.expect(linuxCommands.executeScript("test.sh", true)).to.equal("bash 'test.sh'");
chai.expect(linuxCommands.executeScript("test script'\"", false)).to.equal(`bash -c \"test script'\\""`);
})
});
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import { Client } from 'ssh2';
import { ShellExecutor } from '../shellExecutor';
import { prepareUnitTest, cleanupUnitTest } from '../../../common/utils';
const LOCALFILE: string = '/tmp/localSshclientUTData';
const REMOTEFILE: string = '/tmp/remoteSshclientUTData';
const REMOTEFOLDER: string = '/tmp/remoteSshclientUTFolder';
async function copyFile(executor: ShellExecutor): Promise<void> {
await executor.copyFileToRemote(LOCALFILE, REMOTEFILE);
}
async function copyFileToRemoteLoop(executor: ShellExecutor): Promise<void> {
for (let i: number = 0; i < 10; i++) {
// console.log(i);
await executor.copyFileToRemote(LOCALFILE, REMOTEFILE);
}
}
async function getRemoteFileContentLoop(executor: ShellExecutor): Promise<void> {
for (let i: number = 0; i < 10; i++) {
// console.log(i);
await executor.getRemoteFileContent(REMOTEFILE);
}
}
describe('ShellExecutor test', () => {
let skip: boolean = false;
let rmMeta: any;
try {
rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8'));
console.log(rmMeta);
} catch (err) {
console.log(`Please configure rminfo.json to enable remote machine test.${err}`);
skip = true;
}
before(async () => {
chai.should();
chai.use(chaiAsPromised);
await cpp.exec(`echo '1234' > ${LOCALFILE}`);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
fs.unlinkSync(LOCALFILE);
});
it('Test mkdir', async () => {
if (skip) {
return;
}
const shellExecutor: ShellExecutor = new ShellExecutor();
await shellExecutor.initialize(rmMeta);
let result = await shellExecutor.createFolder(REMOTEFOLDER, false);
chai.expect(result).eq(true);
result = await shellExecutor.removeFolder(REMOTEFOLDER);
chai.expect(result).eq(true);
});
it('Test ShellExecutor', async () => {
if (skip) {
return;
}
const shellExecutor: ShellExecutor = new ShellExecutor();
await shellExecutor.initialize(rmMeta);
await copyFile(shellExecutor);
await Promise.all([
copyFileToRemoteLoop(shellExecutor),
copyFileToRemoteLoop(shellExecutor),
copyFileToRemoteLoop(shellExecutor),
getRemoteFileContentLoop(shellExecutor)
]);
});
});
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import { Client } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { SSHClientUtility } from '../remote_machine/sshClientUtility';
const LOCALFILE: string = '/tmp/sshclientUTData';
const REMOTEFILE: string = '/tmp/sshclientUTData';
async function copyFile(conn: Client): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
conn.sftp((err, sftp) => {
if (err) {
deferred.reject(err);
return;
}
sftp.fastPut(
LOCALFILE,
REMOTEFILE, (fastPutErr: Error) => {
sftp.end();
if (fastPutErr) {
deferred.reject(fastPutErr);
} else {
deferred.resolve();
}
}
);
});
return deferred.promise;
}
async function copyFileToRemoteLoop(conn: Client): Promise<void> {
for (let i: number = 0; i < 500; i++) {
console.log(i);
await SSHClientUtility.copyFileToRemote(LOCALFILE, REMOTEFILE, conn);
}
}
async function remoteExeCommandLoop(conn: Client): Promise<void> {
for (let i: number = 0; i < 500; i++) {
console.log(i);
await SSHClientUtility.remoteExeCommand('ls', conn);
}
}
async function getRemoteFileContentLoop(conn: Client): Promise<void> {
for (let i: number = 0; i < 500; i++) {
console.log(i);
await SSHClientUtility.getRemoteFileContent(REMOTEFILE, conn);
}
}
describe('sshClientUtility test', () => {
let skip: boolean = true;
let rmMeta: any;
try {
rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8'));
} catch (err) {
skip = true;
}
before(async () => {
await cpp.exec(`echo '1234' > ${LOCALFILE}`);
});
after(() => {
fs.unlinkSync(LOCALFILE);
});
it('Test SSHClientUtility', (done) => {
if (skip) {
done();
return;
}
const conn: Client = new Client();
conn.on('ready', async () => {
await copyFile(conn);
await Promise.all([
copyFileToRemoteLoop(conn),
copyFileToRemoteLoop(conn),
copyFileToRemoteLoop(conn),
remoteExeCommandLoop(conn),
getRemoteFileContentLoop(conn)
]);
done();
}).connect(rmMeta);
});
});
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment