"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "cffdc27c05e9ac0b7af36dfa0391b4b1570c5c75"
Unverified Commit de9e2842 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Fix ssh connection error (#829)

SSH client has a max number of open channels for a connection, if we set the number of trialCurrency too big, our ssh client will exec command using ssh frequently, then we will meet the error of Error: (SSH) Channel open failure: open failed.
Refactor the code, set one connection has a max trial concurrency, when the number of trial reach the ssh connection restriction, will create a new ssh connection to exec trial commands.
parent 7d91796c
......@@ -24,21 +24,21 @@ import { Client } from 'ssh2';
import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData';
import { RemoteMachineMeta, RemoteMachineScheduleResult, ScheduleResultType } from './remoteMachineData';
import { RemoteMachineMeta, RemoteMachineScheduleResult, ScheduleResultType, SSHClientManager } from './remoteMachineData';
/**
* A simple GPU scheduler implementation
*/
export class GPUScheduler {
private readonly machineSSHClientMap : Map<RemoteMachineMeta, Client>;
private readonly machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>;
private log: Logger = getLogger();
/**
* Constructor
* @param machineSSHClientMap map from remote machine to sshClient
*/
constructor(machineSSHClientMap : Map<RemoteMachineMeta, Client>) {
constructor(machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>) {
this.machineSSHClientMap = machineSSHClientMap;
}
......@@ -113,7 +113,7 @@ export class GPUScheduler {
*/
private gpuResourceDetection() : Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap : Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
this.machineSSHClientMap.forEach((client: Client, rmMeta: RemoteMachineMeta) => {
this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => {
// Assgin totoal GPU count as init available GPU number
if (rmMeta.gpuSummary !== undefined) {
const availableGPUs: GPUInfo[] = [];
......
......@@ -21,6 +21,9 @@
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUSummary } from '../common/gpuData';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
import * as fs from 'fs';
/**
......@@ -94,6 +97,138 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
}
}
/**
* The remote machine ssh client used for trial and gpu detector
*/
export class SSHClient {
private readonly sshClient: Client;
private usedConnectionNumber: number; //count the connection number of every client
constructor(sshClient: Client, usedConnectionNumber: number) {
this.sshClient = sshClient;
this.usedConnectionNumber = usedConnectionNumber;
}
public get getSSHClientInstance(): Client {
return this.sshClient;
}
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
public addUsedConnectionNumber() {
this.usedConnectionNumber += 1;
}
public minusUsedConnectionNumber() {
this.usedConnectionNumber -= 1;
}
}
export class SSHClientManager {
private sshClientArray: SSHClient[];
private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta;
constructor(sshClientArray: SSHClient[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
this.rmMeta = rmMeta;
this.sshClientArray = sshClientArray;
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
}
/**
* Create a new ssh connection client and initialize it
*/
private initNewSSHClient(): Promise<Client> {
const deferred: Deferred<Client> = new Deferred<Client>();
const conn: Client = new Client();
let connectConfig: ConnectConfig = {
host: this.rmMeta.ip,
port: this.rmMeta.port,
username: this.rmMeta.username };
if (this.rmMeta.passwd) {
connectConfig.password = this.rmMeta.passwd;
} else if(this.rmMeta.sshKeyPath) {
if(!fs.existsSync(this.rmMeta.sshKeyPath)) {
//SSh key path is not a valid file, reject
deferred.reject(new Error(`${this.rmMeta.sshKeyPath} does not exist.`));
}
const privateKey: string = fs.readFileSync(this.rmMeta.sshKeyPath, 'utf8');
connectConfig.privateKey = privateKey;
connectConfig.passphrase = this.rmMeta.passphrase;
} else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
}
conn.on('ready', () => {
this.addNewSSHClient(conn);
deferred.resolve(conn);
}).on('error', (err: Error) => {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).connect(connectConfig);
return deferred.promise;
}
/**
* find a available ssh client in ssh array, if no ssh client available, return undefined
*/
public async getAvailableSSHClient(): Promise<Client> {
const deferred: Deferred<Client> = new Deferred<Client>();
for (const index in this.sshClientArray) {
let connectionNumber: number = this.sshClientArray[index].getUsedConnectionNumber;
if(connectionNumber < this.maxTrialNumberPerConnection) {
this.sshClientArray[index].addUsedConnectionNumber();
deferred.resolve(this.sshClientArray[index].getSSHClientInstance);
return deferred.promise;
}
};
//init a new ssh client if could not get an available one
return await this.initNewSSHClient();
}
/**
* add a new ssh client to sshClientArray
* @param sshClient
*/
public addNewSSHClient(client: Client) {
this.sshClientArray.push(new SSHClient(client, 1));
}
/**
* first ssh clilent instance is used for gpu collector and host job
*/
public getFirstSSHClient() {
return this.sshClientArray[0].getSSHClientInstance;
}
/**
* close all of ssh client
*/
public closeAllSSHClient() {
for (let sshClient of this.sshClientArray) {
sshClient.getSSHClientInstance.end();
}
}
/**
* retrieve resource, minus a number for given ssh client
* @param client
*/
public releaseConnection(client: Client | undefined) {
if(!client) {
throw new Error(`could not release a undefined ssh client`);
}
for(let index in this.sshClientArray) {
if(this.sshClientArray[index].getSSHClientInstance === client) {
this.sshClientArray[index].minusUsedConnectionNumber();
break;
}
}
}
}
export type RemoteMachineScheduleResult = { scheduleInfo : RemoteMachineScheduleInfo | undefined; resultType : ScheduleResultType};
export type RemoteMachineScheduleInfo = { rmMeta : RemoteMachineMeta; cuda_visible_device : string};
......
......@@ -43,7 +43,7 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { GPUScheduler } from './gpuScheduler';
import {
HOST_JOB_SHELL_FORMAT, RemoteCommandResult, RemoteMachineMeta,
RemoteMachineScheduleInfo, RemoteMachineScheduleResult,
RemoteMachineScheduleInfo, RemoteMachineScheduleResult, SSHClient, SSHClientManager,
RemoteMachineTrialJobDetail, ScheduleResultType, REMOTEMACHINE_TRIAL_COMMAND_FORMAT,
GPU_COLLECTOR_FORMAT
} from './remoteMachineData';
......@@ -58,8 +58,10 @@ import { mkDirP } from '../../common/utils';
*/
@component.Singleton
class RemoteMachineTrainingService implements TrainingService {
private machineSSHClientMap: Map<RemoteMachineMeta, Client>;
private machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; //machine ssh client map
private trialSSHClientMap: Map<string, Client>; //trial ssh client map
private trialJobsMap: Map<string, RemoteMachineTrialJobDetail>;
private readonly MAX_TRIAL_NUMBER_PER_SSHCONNECTION: number = 5 // every ssh client has a max trial concurrency number
private expRootDir: string;
private remoteExpRootDir: string;
private trialConfig: TrialConfig | undefined;
......@@ -79,7 +81,8 @@ class RemoteMachineTrainingService implements TrainingService {
this.remoteOS = 'linux';
this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.machineSSHClientMap = new Map<RemoteMachineMeta, Client>();
this.trialSSHClientMap = new Map<string, Client>();
this.machineSSHClientMap = new Map<RemoteMachineMeta, SSHClientManager>();
this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap);
this.jobQueue = [];
this.expRootDir = getExperimentRootDir();
......@@ -116,6 +119,40 @@ class RemoteMachineTrainingService implements TrainingService {
this.log.info('Remote machine training service exit.');
}
/**
* give trial a ssh connection
* @param trial
*/
public async allocateSSHClientForTrial(trial: RemoteMachineTrialJobDetail): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if(!trial.rmMeta) {
throw new Error(`rmMeta not set in trial ${trial.id}`);
}
let sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(trial.rmMeta);
if(!sshClientManager) {
throw new Error(`remoteSSHClient not initialized`);
}
let sshClient: Client = await sshClientManager.getAvailableSSHClient();
this.trialSSHClientMap.set(trial.id, sshClient);
deferred.resolve();
return deferred.promise;
}
/**
* If a trial is finished, release the connection resource
* @param trial
*/
public releaseTrialSSHClient(trial: RemoteMachineTrialJobDetail): void {
if(!trial.rmMeta) {
throw new Error(`rmMeta not set in trial ${trial.id}`);
}
let sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(trial.rmMeta);
if(!sshClientManager) {
throw new Error(`sshClientManager not initialized`);
}
sshClientManager.releaseConnection(this.trialSSHClientMap.get(trial.id));
}
/**
* List submitted trial jobs
*/
......@@ -148,7 +185,7 @@ class RemoteMachineTrainingService implements TrainingService {
if (trialJob.rmMeta === undefined) {
throw new Error(`rmMeta not set for submitted job ${trialJobId}`);
}
const sshClient: Client | undefined = this.machineSSHClientMap.get(trialJob.rmMeta);
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id);
if (!sshClient) {
throw new Error(`Invalid job id: ${trialJobId}, cannot find ssh client`);
}
......@@ -179,7 +216,7 @@ class RemoteMachineTrainingService implements TrainingService {
* Submit trial job
* @param form trial job description form
*/
public submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
if (!this.trialConfig) {
throw new Error('trial config is not initialized');
}
......@@ -271,7 +308,7 @@ class RemoteMachineTrainingService implements TrainingService {
// Get ssh client where the job is running
if (trialJob.rmMeta !== undefined) {
// If the trial job is already scheduled, check its status and kill the trial process in remote machine
const sshClient: Client | undefined = this.machineSSHClientMap.get(trialJob.rmMeta);
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id);
if (!sshClient) {
deferred.reject();
throw new Error(`Invalid job id ${trialJobId}, cannot find ssh client`);
......@@ -282,6 +319,7 @@ class RemoteMachineTrainingService implements TrainingService {
// Mark the toEarlyStop tag here
trialJob.isEarlyStopped = isEarlyStopped;
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, sshClient);
this.releaseTrialSSHClient(trialJob);
} catch (error) {
// Not handle the error since pkill failed will not impact trial job's current status
this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`);
......@@ -364,11 +402,15 @@ class RemoteMachineTrainingService implements TrainingService {
*/
private async cleanupConnections(): Promise<void> {
try{
for (const [rmMeta, client] of this.machineSSHClientMap.entries()) {
for (const [rmMeta, sshClientManager] of this.machineSSHClientMap.entries()) {
let jobpidPath: string = path.join(this.getRemoteScriptsPath(rmMeta.username), 'pid');
let client: Client | undefined = sshClientManager.getFirstSSHClient();
if(client) {
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, client);
await SSHClientUtility.remoteExeCommand(`rm -rf ${this.getRemoteScriptsPath(rmMeta.username)}`, client);
}
sshClientManager.closeAllSSHClient();
}
}catch (error) {
//ignore error, this function is called to cleanup remote connections when experiment is stopping
this.log.error(`Cleanup connection exception, error is ${error.message}`);
......@@ -410,37 +452,14 @@ class RemoteMachineTrainingService implements TrainingService {
const rmMetaList: RemoteMachineMeta[] = <RemoteMachineMeta[]>JSON.parse(machineList);
let connectedRMNum: number = 0;
rmMetaList.forEach((rmMeta: RemoteMachineMeta) => {
const conn: Client = new Client();
let connectConfig: ConnectConfig = {
host: rmMeta.ip,
port: rmMeta.port,
username: rmMeta.username };
if (rmMeta.passwd) {
connectConfig.password = rmMeta.passwd;
} else if(rmMeta.sshKeyPath) {
if(!fs.existsSync(rmMeta.sshKeyPath)) {
//SSh key path is not a valid file, reject
deferred.reject(new Error(`${rmMeta.sshKeyPath} does not exist.`));
}
const privateKey: string = fs.readFileSync(rmMeta.sshKeyPath, 'utf8');
connectConfig.privateKey = privateKey;
connectConfig.passphrase = rmMeta.passphrase;
} else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
}
this.machineSSHClientMap.set(rmMeta, conn);
conn.on('ready', async () => {
this.machineSSHClientMap.set(rmMeta, conn);
await this.initRemoteMachineOnConnected(rmMeta, conn);
rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => {
let sshClientManager: SSHClientManager = new SSHClientManager([], this.MAX_TRIAL_NUMBER_PER_SSHCONNECTION, rmMeta);
let sshClient: Client = await sshClientManager.getAvailableSSHClient();
this.machineSSHClientMap.set(rmMeta, sshClientManager);
await this.initRemoteMachineOnConnected(rmMeta, sshClient);
if (++connectedRMNum === rmMetaList.length) {
deferred.resolve();
}
}).on('error', (err: Error) => {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).connect(connectConfig);
});
return deferred.promise;
}
......@@ -499,13 +518,16 @@ class RemoteMachineTrainingService implements TrainingService {
&& rmScheduleResult.scheduleInfo !== undefined) {
const rmScheduleInfo : RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId);
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
await this.allocateSSHClientForTrial(trialJobDetail);
await this.launchTrialOnScheduledMachine(
trialJobId, trialWorkingFolder, <TrialJobApplicationForm>trialJobDetail.form, rmScheduleInfo);
trialJobDetail.status = 'RUNNING';
trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`;
trialJobDetail.startTime = Date.now();
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
deferred.resolve(true);
} else if (rmScheduleResult.resultType === ScheduleResultType.TMP_NO_AVAILABLE_GPU) {
......@@ -524,7 +546,7 @@ class RemoteMachineTrainingService implements TrainingService {
throw new Error('trial config is not initialized');
}
const cuda_visible_device: string = rmScheduleInfo.cuda_visible_device;
const sshClient: Client | undefined = this.machineSSHClientMap.get(rmScheduleInfo.rmMeta);
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
if (sshClient === undefined) {
assert(false, 'sshClient is undefined.');
......@@ -592,10 +614,11 @@ class RemoteMachineTrainingService implements TrainingService {
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
const rmMeta: RemoteMachineMeta = this.getRmMetaByHost(form.host);
const sshClient: Client | undefined = this.machineSSHClientMap.get(rmMeta);
if (sshClient === undefined) {
const sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(rmMeta);
if (sshClientManager === undefined) {
throw new Error('sshClient not found.');
}
let sshClient: Client = sshClientManager.getFirstSSHClient();
const jobId: string = uniqueString(5);
const localDir: string = path.join(this.expRootDir, 'hostjobs-local', jobId);
const remoteDir: string = this.getHostJobRemoteDir(jobId);
......@@ -654,6 +677,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
}
trialJob.endTime = parseInt(timestamp, 10);
this.releaseTrialSSHClient(trialJob);
}
this.log.debug(`trailJob status update: ${trialJob.id}, ${trialJob.status}`);
}
......@@ -705,7 +729,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
const sshClient: Client | undefined = this.machineSSHClientMap.get(rmMeta);
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
if (sshClient === undefined) {
throw new Error('sshClient is undefined.');
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment