Unverified Commit 611a45fc authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #19 from microsoft/master

pull code
parents 841d4677 e267a737
...@@ -15,7 +15,7 @@ $yarnUrl = "https://yarnpkg.com/latest.tar.gz" ...@@ -15,7 +15,7 @@ $yarnUrl = "https://yarnpkg.com/latest.tar.gz"
$unzipNodeDir = "node-v*" $unzipNodeDir = "node-v*"
$unzipYarnDir = "yarn-v*" $unzipYarnDir = "yarn-v*"
$NNI_DEPENDENCY_FOLDER = "C:\tmp\$env:USERNAME" $NNI_DEPENDENCY_FOLDER = [System.IO.Path]::GetTempPath()+$env:USERNAME
$WHICH_PYTHON = where.exe python $WHICH_PYTHON = where.exe python
if($WHICH_PYTHON -eq $null){ if($WHICH_PYTHON -eq $null){
......
...@@ -70,6 +70,18 @@ interface TrialJobInfo { ...@@ -70,6 +70,18 @@ interface TrialJobInfo {
stderrPath?: string; stderrPath?: string;
} }
interface HyperParameterFormat {
parameter_source: string;
parameters: Object;
parameter_id: number;
}
interface ExportedDataFormat {
parameter: Object;
value: Object;
id: string;
}
abstract class DataStore { abstract class DataStore {
public abstract init(): Promise<void>; public abstract init(): Promise<void>;
public abstract close(): Promise<void>; public abstract close(): Promise<void>;
...@@ -82,6 +94,8 @@ abstract class DataStore { ...@@ -82,6 +94,8 @@ abstract class DataStore {
public abstract getTrialJob(trialJobId: string): Promise<TrialJobInfo>; public abstract getTrialJob(trialJobId: string): Promise<TrialJobInfo>;
public abstract storeMetricData(trialJobId: string, data: string): Promise<void>; public abstract storeMetricData(trialJobId: string, data: string): Promise<void>;
public abstract getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]>; public abstract getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]>;
public abstract exportTrialHpConfigs(): Promise<string>;
public abstract getImportedData(): Promise<string[]>;
} }
abstract class Database { abstract class Database {
...@@ -99,5 +113,5 @@ abstract class Database { ...@@ -99,5 +113,5 @@ abstract class Database {
export { export {
DataStore, Database, TrialJobEvent, MetricType, MetricData, TrialJobInfo, DataStore, Database, TrialJobEvent, MetricType, MetricData, TrialJobInfo,
ExperimentProfileRecord, TrialJobEventRecord, MetricDataRecord ExperimentProfileRecord, TrialJobEventRecord, MetricDataRecord, HyperParameterFormat, ExportedDataFormat
}; };
...@@ -100,6 +100,7 @@ abstract class Manager { ...@@ -100,6 +100,7 @@ abstract class Manager {
public abstract getExperimentProfile(): Promise<ExperimentProfile>; public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>; public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract importData(data: string): Promise<void>; public abstract importData(data: string): Promise<void>;
public abstract exportData(): Promise<string>;
public abstract addCustomizedTrialJob(hyperParams: string): Promise<void>; public abstract addCustomizedTrialJob(hyperParams: string): Promise<void>;
public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>; public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>;
......
...@@ -43,11 +43,11 @@ function getExperimentRootDir(): string { ...@@ -43,11 +43,11 @@ function getExperimentRootDir(): string {
.getLogDir(); .getLogDir();
} }
function getLogDir(): string{ function getLogDir(): string {
return path.join(getExperimentRootDir(), 'log'); return path.join(getExperimentRootDir(), 'log');
} }
function getLogLevel(): string{ function getLogLevel(): string {
return getExperimentStartupInfo() return getExperimentStartupInfo()
.getLogLevel(); .getLogLevel();
} }
...@@ -149,7 +149,7 @@ function parseArg(names: string[]): string { ...@@ -149,7 +149,7 @@ function parseArg(names: string[]): string {
return ''; return '';
} }
function encodeCmdLineArgs(args:any):any{ function encodeCmdLineArgs(args: any): any {
if(process.platform === 'win32'){ if(process.platform === 'win32'){
return JSON.stringify(args); return JSON.stringify(args);
} }
...@@ -158,7 +158,7 @@ function encodeCmdLineArgs(args:any):any{ ...@@ -158,7 +158,7 @@ function encodeCmdLineArgs(args:any):any{
} }
} }
function getCmdPy():string{ function getCmdPy(): string {
let cmd = 'python3'; let cmd = 'python3';
if(process.platform === 'win32'){ if(process.platform === 'win32'){
cmd = 'python'; cmd = 'python';
...@@ -390,7 +390,7 @@ async function getVersion(): Promise<string> { ...@@ -390,7 +390,7 @@ async function getVersion(): Promise<string> {
/** /**
* run command as ChildProcess * run command as ChildProcess
*/ */
function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newEnv: any): ChildProcess{ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newEnv: any): ChildProcess {
let cmd: string = command; let cmd: string = command;
let arg: string[] = []; let arg: string[] = [];
let newShell: boolean = true; let newShell: boolean = true;
...@@ -411,7 +411,7 @@ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newE ...@@ -411,7 +411,7 @@ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newE
/** /**
* judge whether the process is alive * judge whether the process is alive
*/ */
async function isAlive(pid:any): Promise<boolean>{ async function isAlive(pid:any): Promise<boolean> {
let deferred : Deferred<boolean> = new Deferred<boolean>(); let deferred : Deferred<boolean> = new Deferred<boolean>();
let alive: boolean = false; let alive: boolean = false;
if(process.platform ==='win32'){ if(process.platform ==='win32'){
...@@ -439,7 +439,7 @@ async function isAlive(pid:any): Promise<boolean>{ ...@@ -439,7 +439,7 @@ async function isAlive(pid:any): Promise<boolean>{
/** /**
* kill process * kill process
*/ */
async function killPid(pid:any): Promise<void>{ async function killPid(pid:any): Promise<void> {
let deferred : Deferred<void> = new Deferred<void>(); let deferred : Deferred<void> = new Deferred<void>();
try { try {
if (process.platform === "win32") { if (process.platform === "win32") {
...@@ -455,7 +455,7 @@ async function killPid(pid:any): Promise<void>{ ...@@ -455,7 +455,7 @@ async function killPid(pid:any): Promise<void>{
return deferred.promise; return deferred.promise;
} }
function getNewLine(): string{ function getNewLine(): string {
if (process.platform === "win32") { if (process.platform === "win32") {
return "\r\n"; return "\r\n";
} }
......
...@@ -24,7 +24,8 @@ import { Deferred } from 'ts-deferred'; ...@@ -24,7 +24,8 @@ import { Deferred } from 'ts-deferred';
import * as component from '../common/component'; import * as component from '../common/component';
import { Database, DataStore, MetricData, MetricDataRecord, MetricType, import { Database, DataStore, MetricData, MetricDataRecord, MetricType,
TrialJobEvent, TrialJobEventRecord, TrialJobInfo } from '../common/datastore'; TrialJobEvent, TrialJobEventRecord, TrialJobInfo, HyperParameterFormat,
ExportedDataFormat } from '../common/datastore';
import { NNIError } from '../common/errors'; import { NNIError } from '../common/errors';
import { getExperimentId, isNewExperiment } from '../common/experimentStartupInfo'; import { getExperimentId, isNewExperiment } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
...@@ -171,6 +172,61 @@ class NNIDataStore implements DataStore { ...@@ -171,6 +172,61 @@ class NNIDataStore implements DataStore {
return this.db.queryMetricData(trialJobId, metricType); return this.db.queryMetricData(trialJobId, metricType);
} }
public async exportTrialHpConfigs(): Promise<string> {
const jobs: TrialJobInfo[] = await this.listTrialJobs();
let exportedData: ExportedDataFormat[] = [];
for (const job of jobs) {
if (job.hyperParameters && job.finalMetricData) {
if (job.hyperParameters.length === 1 && job.finalMetricData.length === 1) {
// optimization for non-multi-phase case
const parameters: HyperParameterFormat = <HyperParameterFormat>JSON.parse(job.hyperParameters[0]);
const oneEntry: ExportedDataFormat = {
parameter: parameters.parameters,
value: JSON.parse(job.finalMetricData[0].data),
id: job.id
};
exportedData.push(oneEntry);
} else {
let paraMap: Map<number, Object> = new Map();
let metricMap: Map<number, Object> = new Map();
for (const eachPara of job.hyperParameters) {
const parameters: HyperParameterFormat = <HyperParameterFormat>JSON.parse(eachPara);
paraMap.set(parameters.parameter_id, parameters.parameters);
}
for (const eachMetric of job.finalMetricData) {
const value: Object = JSON.parse(eachMetric.data);
metricMap.set(Number(eachMetric.parameterId), value);
}
paraMap.forEach((value: Object, key: number) => {
const metricValue: Object | undefined = metricMap.get(key);
if (metricValue) {
const oneEntry: ExportedDataFormat = {
parameter: value,
value: metricValue,
id: job.id
};
exportedData.push(oneEntry);
}
});
}
}
}
return JSON.stringify(exportedData);
}
public async getImportedData(): Promise<string[]> {
let importedData: string[] = [];
const importDataEvents: TrialJobEventRecord[] = await this.db.queryTrialJobEvent(undefined, 'IMPORT_DATA');
for (const event of importDataEvents) {
if (event.data) {
importedData.push(event.data);
}
}
return importedData;
}
private async queryTrialJobs(status?: TrialJobStatus, trialJobId?: string): Promise<TrialJobInfo[]> { private async queryTrialJobs(status?: TrialJobStatus, trialJobId?: string): Promise<TrialJobInfo[]> {
const result: TrialJobInfo[] = []; const result: TrialJobInfo[] = [];
const trialJobEvents: TrialJobEventRecord[] = await this.db.queryTrialJobEvent(trialJobId); const trialJobEvents: TrialJobEventRecord[] = await this.db.queryTrialJobEvent(trialJobId);
......
...@@ -58,7 +58,10 @@ class NNIManager implements Manager { ...@@ -58,7 +58,10 @@ class NNIManager implements Manager {
private status: NNIManagerStatus; private status: NNIManagerStatus;
private waitingTrials: string[]; private waitingTrials: string[];
private trialJobs: Map<string, TrialJobDetail>; private trialJobs: Map<string, TrialJobDetail>;
private trialDataForTuner: string;
private trialJobMetricListener: (metric: TrialJobMetric) => void;
constructor() { constructor() {
this.currSubmittedTrialNum = 0; this.currSubmittedTrialNum = 0;
this.trialConcurrencyChange = 0; this.trialConcurrencyChange = 0;
...@@ -68,6 +71,7 @@ class NNIManager implements Manager { ...@@ -68,6 +71,7 @@ class NNIManager implements Manager {
this.dispatcherPid = 0; this.dispatcherPid = 0;
this.waitingTrials = []; this.waitingTrials = [];
this.trialJobs = new Map<string, TrialJobDetail>(); this.trialJobs = new Map<string, TrialJobDetail>();
this.trialDataForTuner = '';
this.log = getLogger(); this.log = getLogger();
this.dataStore = component.get(DataStore); this.dataStore = component.get(DataStore);
...@@ -76,6 +80,11 @@ class NNIManager implements Manager { ...@@ -76,6 +80,11 @@ class NNIManager implements Manager {
status: 'INITIALIZED', status: 'INITIALIZED',
errors: [] errors: []
}; };
this.trialJobMetricListener = (metric: TrialJobMetric) => {
this.onTrialJobMetrics(metric).catch((err: Error) => {
this.criticalError(NNIError.FromError(err, 'Job metrics error: '));
});
};
} }
public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> { public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> {
...@@ -110,6 +119,10 @@ class NNIManager implements Manager { ...@@ -110,6 +119,10 @@ class NNIManager implements Manager {
return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data); return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data);
} }
public async exportData(): Promise<string> {
return this.dataStore.exportTrialHpConfigs();
}
public addCustomizedTrialJob(hyperParams: string): Promise<void> { public addCustomizedTrialJob(hyperParams: string): Promise<void> {
if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
return Promise.reject( return Promise.reject(
...@@ -206,6 +219,16 @@ class NNIManager implements Manager { ...@@ -206,6 +219,16 @@ class NNIManager implements Manager {
.filter((job: TrialJobInfo) => job.status === 'WAITING' || job.status === 'RUNNING') .filter((job: TrialJobInfo) => job.status === 'WAITING' || job.status === 'RUNNING')
.map((job: TrialJobInfo) => this.dataStore.storeTrialJobEvent('FAILED', job.id))); .map((job: TrialJobInfo) => this.dataStore.storeTrialJobEvent('FAILED', job.id)));
// Collect generated trials and imported trials
const finishedTrialData: string = await this.exportData();
const importedData: string[] = await this.dataStore.getImportedData();
let trialData: Object[] = JSON.parse(finishedTrialData);
for (const oneImportedData of importedData) {
// do not deduplicate
trialData = trialData.concat(<Object[]>JSON.parse(oneImportedData));
}
this.trialDataForTuner = JSON.stringify(trialData);
if (this.experimentProfile.execDuration < this.experimentProfile.params.maxExecDuration && if (this.experimentProfile.execDuration < this.experimentProfile.params.maxExecDuration &&
this.currSubmittedTrialNum < this.experimentProfile.params.maxTrialNum && this.currSubmittedTrialNum < this.experimentProfile.params.maxTrialNum &&
this.experimentProfile.endTime) { this.experimentProfile.endTime) {
...@@ -342,6 +365,7 @@ class NNIManager implements Manager { ...@@ -342,6 +365,7 @@ class NNIManager implements Manager {
if (this.dispatcher === undefined) { if (this.dispatcher === undefined) {
throw new Error('Error: tuner has not been setup'); throw new Error('Error: tuner has not been setup');
} }
this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
this.dispatcher.sendCommand(TERMINATE); this.dispatcher.sendCommand(TERMINATE);
let tunerAlive: boolean = true; let tunerAlive: boolean = true;
// gracefully terminate tuner and assessor here, wait at most 30 seconds. // gracefully terminate tuner and assessor here, wait at most 30 seconds.
...@@ -589,11 +613,7 @@ class NNIManager implements Manager { ...@@ -589,11 +613,7 @@ class NNIManager implements Manager {
if (this.dispatcher === undefined) { if (this.dispatcher === undefined) {
throw new Error('Error: tuner or job maintainer have not been setup'); throw new Error('Error: tuner or job maintainer have not been setup');
} }
this.trainingService.addTrialJobMetricListener((metric: TrialJobMetric) => { this.trainingService.addTrialJobMetricListener(this.trialJobMetricListener);
this.onTrialJobMetrics(metric).catch((err: Error) => {
this.criticalError(NNIError.FromError(err, 'Job metrics error: '));
});
});
this.dispatcher.onCommand((commandType: string, content: string) => { this.dispatcher.onCommand((commandType: string, content: string) => {
this.onTunerCommand(commandType, content).catch((err: Error) => { this.onTunerCommand(commandType, content).catch((err: Error) => {
...@@ -644,6 +664,12 @@ class NNIManager implements Manager { ...@@ -644,6 +664,12 @@ class NNIManager implements Manager {
switch (commandType) { switch (commandType) {
case INITIALIZED: case INITIALIZED:
// Tuner is intialized, search space is set, request tuner to generate hyper parameters // Tuner is intialized, search space is set, request tuner to generate hyper parameters
if (this.trialDataForTuner.length > 0) {
if (this.dispatcher === undefined) {
throw new Error('Dispatcher error: tuner has not been setup');
}
this.dispatcher.sendCommand(IMPORT_DATA, this.trialDataForTuner);
}
this.requestTrialJobs(this.experimentProfile.params.trialConcurrency); this.requestTrialJobs(this.experimentProfile.params.trialConcurrency);
break; break;
case NEW_TRIAL_JOB: case NEW_TRIAL_JOB:
......
...@@ -210,6 +210,16 @@ class MockedDataStore implements DataStore { ...@@ -210,6 +210,16 @@ class MockedDataStore implements DataStore {
return result; return result;
} }
async exportTrialHpConfigs(): Promise<string> {
const ret: string = '';
return Promise.resolve(ret);
}
async getImportedData(): Promise<string[]> {
const ret: string[] = [];
return Promise.resolve(ret);
}
public getTrialJob(trialJobId: string): Promise<TrialJobInfo> { public getTrialJob(trialJobId: string): Promise<TrialJobInfo> {
throw new Error("Method not implemented."); throw new Error("Method not implemented.");
} }
......
...@@ -72,6 +72,7 @@ class NNIRestHandler { ...@@ -72,6 +72,7 @@ class NNIRestHandler {
this.addTrialJob(router); this.addTrialJob(router);
this.cancelTrialJob(router); this.cancelTrialJob(router);
this.getMetricData(router); this.getMetricData(router);
this.exportData(router);
// Express-joi-validator configuration // Express-joi-validator configuration
router.use((err: any, req: Request, res: Response, next: any) => { router.use((err: any, req: Request, res: Response, next: any) => {
...@@ -261,6 +262,16 @@ class NNIRestHandler { ...@@ -261,6 +262,16 @@ class NNIRestHandler {
}); });
} }
private exportData(router: Router): void {
router.get('/export-data', (req: Request, res: Response) => {
this.nniManager.exportData().then((exportedData: string) => {
res.send(exportedData);
}).catch((err: Error) => {
this.handle_error(err, res);
});
});
}
private setErrorPathForFailedJob(jobInfo: TrialJobInfo): TrialJobInfo { private setErrorPathForFailedJob(jobInfo: TrialJobInfo): TrialJobInfo {
if (jobInfo === undefined || jobInfo.status !== 'FAILED' || jobInfo.logPath === undefined) { if (jobInfo === undefined || jobInfo.status !== 'FAILED' || jobInfo.logPath === undefined) {
return jobInfo; return jobInfo;
......
...@@ -31,10 +31,14 @@ export namespace ValidationSchemas { ...@@ -31,10 +31,14 @@ export namespace ValidationSchemas {
passwd: joi.string(), passwd: joi.string(),
sshKeyPath: joi.string(), sshKeyPath: joi.string(),
passphrase: joi.string(), passphrase: joi.string(),
gpuIndices: joi.string() gpuIndices: joi.string(),
maxTrialNumPerGpu: joi.number(),
useActiveGpu: joi.boolean()
})), })),
local_config: joi.object({ local_config: joi.object({
gpuIndices: joi.string() gpuIndices: joi.string(),
maxTrialNumPerGpu: joi.number(),
useActiveGpu: joi.boolean()
}), }),
trial_config: joi.object({ trial_config: joi.object({
image: joi.string().min(1), image: joi.string().min(1),
......
...@@ -49,6 +49,10 @@ export class MockedNNIManager extends Manager { ...@@ -49,6 +49,10 @@ export class MockedNNIManager extends Manager {
public importData(data: string): Promise<void> { public importData(data: string): Promise<void> {
return Promise.resolve(); return Promise.resolve();
} }
public async exportData(): Promise<string> {
const ret: string = '';
return Promise.resolve(ret);
}
public getTrialJobStatistics(): Promise<TrialJobStatistics[]> { public getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
const deferred: Deferred<TrialJobStatistics[]> = new Deferred<TrialJobStatistics[]>(); const deferred: Deferred<TrialJobStatistics[]> = new Deferred<TrialJobStatistics[]>();
deferred.resolve([{ deferred.resolve([{
......
...@@ -24,7 +24,10 @@ import { getLogger } from "common/log"; ...@@ -24,7 +24,10 @@ import { getLogger } from "common/log";
import { countFilesRecursively } from '../../common/utils' import { countFilesRecursively } from '../../common/utils'
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as cp from 'child_process'; import * as cp from 'child_process';
import { GPU_INFO_COLLECTOR_FORMAT_LINUX, GPU_INFO_COLLECTOR_FORMAT_WINDOWS } from './gpuData' import * as os from 'os';
import * as fs from 'fs';
import { getNewLine } from '../../common/utils';
import { GPU_INFO_COLLECTOR_FORMAT_LINUX, GPU_INFO_COLLECTOR_FORMAT_WINDOWS } from './gpuData';
import * as path from 'path'; import * as path from 'path';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { file } from "../../node_modules/@types/tmp"; import { file } from "../../node_modules/@types/tmp";
...@@ -66,6 +69,20 @@ export async function execMkdir(directory: string): Promise<void> { ...@@ -66,6 +69,20 @@ export async function execMkdir(directory: string): Promise<void> {
return Promise.resolve(); return Promise.resolve();
} }
/**
* copy files to the directory
* @param source
* @param destination
*/
export async function execCopydir(source: string, destination: string): Promise<void> {
if (process.platform === 'win32') {
await cpp.exec(`powershell.exe Copy-Item ${source} -Destination ${destination} -Recurse`);
} else {
await cpp.exec(`cp -r ${source} ${destination}`);
}
return Promise.resolve();
}
/** /**
* crete a new file * crete a new file
* @param filename * @param filename
...@@ -91,8 +108,6 @@ export function execScript(filePath: string): cp.ChildProcess { ...@@ -91,8 +108,6 @@ export function execScript(filePath: string): cp.ChildProcess {
} }
} }
/** /**
* output the last line of a file * output the last line of a file
* @param filePath * @param filePath
...@@ -111,9 +126,9 @@ export async function execTail(filePath: string): Promise<cpp.childProcessPromis ...@@ -111,9 +126,9 @@ export async function execTail(filePath: string): Promise<cpp.childProcessPromis
* delete a directory * delete a directory
* @param directory * @param directory
*/ */
export async function execRemove(directory: string): Promise<void>{ export async function execRemove(directory: string): Promise<void> {
if (process.platform === 'win32') { if (process.platform === 'win32') {
await cpp.exec(`powershell.exe Remove-Item ${directory}`); await cpp.exec(`powershell.exe Remove-Item ${directory} -Recurse -Force`);
} else { } else {
await cpp.exec(`rm -rf ${directory}`); await cpp.exec(`rm -rf ${directory}`);
} }
...@@ -124,7 +139,7 @@ export async function execRemove(directory: string): Promise<void>{ ...@@ -124,7 +139,7 @@ export async function execRemove(directory: string): Promise<void>{
* kill a process * kill a process
* @param directory * @param directory
*/ */
export async function execKill(pid: string): Promise<void>{ export async function execKill(pid: string): Promise<void> {
if (process.platform === 'win32') { if (process.platform === 'win32') {
await cpp.exec(`cmd /c taskkill /PID ${pid} /T /F`); await cpp.exec(`cmd /c taskkill /PID ${pid} /T /F`);
} else { } else {
...@@ -138,7 +153,7 @@ export async function execKill(pid: string): Promise<void>{ ...@@ -138,7 +153,7 @@ export async function execKill(pid: string): Promise<void>{
* @param variable * @param variable
* @returns command string * @returns command string
*/ */
export function setEnvironmentVariable(variable: { key: string; value: string }): string{ export function setEnvironmentVariable(variable: { key: string; value: string }): string {
if (process.platform === 'win32') { if (process.platform === 'win32') {
return `$env:${variable.key}="${variable.value}"`; return `$env:${variable.key}="${variable.value}"`;
} }
...@@ -147,6 +162,32 @@ export function setEnvironmentVariable(variable: { key: string; value: string }) ...@@ -147,6 +162,32 @@ export function setEnvironmentVariable(variable: { key: string; value: string })
} }
} }
/**
* Compress files in directory to tar file
* @param source_path
* @param tar_path
*/
export async function tarAdd(tar_path: string, source_path: string): Promise<void> {
if (process.platform === 'win32') {
tar_path = tar_path.split('\\').join('\\\\');
source_path = source_path.split('\\').join('\\\\');
let script: string[] = [];
script.push(
`import os`,
`import tarfile`,
String.Format(`tar = tarfile.open("{0}","w:gz")\r\nfor root,dir,files in os.walk("{1}"):`, tar_path, source_path),
` for file in files:`,
` fullpath = os.path.join(root,file)`,
` tar.add(fullpath, arcname=file)`,
`tar.close()`);
await fs.promises.writeFile(path.join(os.tmpdir(), 'tar.py'), script.join(getNewLine()), { encoding: 'utf8', mode: 0o777 });
const tarScript: string = path.join(os.tmpdir(), 'tar.py');
await cpp.exec(`python ${tarScript}`);
} else {
await cpp.exec(`tar -czf ${tar_path} -C ${source_path} .`);
}
return Promise.resolve();
}
/** /**
* generate script file name * generate script file name
......
...@@ -71,14 +71,15 @@ class GPUScheduler { ...@@ -71,14 +71,15 @@ class GPUScheduler {
execScript(gpuMetricsCollectorScriptPath) execScript(gpuMetricsCollectorScriptPath)
} }
public getAvailableGPUIndices(): number[] { public getAvailableGPUIndices(useActiveGpu: boolean, occupiedGpuIndexNumMap: Map<number, number>): number[] {
if (this.gpuSummary !== undefined) { if (this.gpuSummary !== undefined) {
if(process.platform === 'win32') { if(process.platform === 'win32' || useActiveGpu) {
return this.gpuSummary.gpuInfos.map((info: GPUInfo) => info.index); return this.gpuSummary.gpuInfos.map((info: GPUInfo) => info.index);
} }
else{ else{
return this.gpuSummary.gpuInfos.filter((info: GPUInfo) => info.activeProcessNum === 0) return this.gpuSummary.gpuInfos.filter((info: GPUInfo) =>
.map((info: GPUInfo) => info.index); occupiedGpuIndexNumMap.get(info.index) === undefined && info.activeProcessNum === 0 ||
occupiedGpuIndexNumMap.get(info.index) !== undefined).map((info: GPUInfo) => info.index);
} }
} }
......
...@@ -97,11 +97,19 @@ class LocalTrialJobDetail implements TrialJobDetail { ...@@ -97,11 +97,19 @@ class LocalTrialJobDetail implements TrialJobDetail {
* Local training service config * Local training service config
*/ */
class LocalConfig { class LocalConfig {
public maxTrialNumPerGpu?: number;
public gpuIndices?: string; public gpuIndices?: string;
constructor(gpuIndices?: string) { public useActiveGpu?: boolean;
constructor(gpuIndices?: string, maxTrialNumPerGpu?: number, useActiveGpu?: boolean) {
if (gpuIndices !== undefined) { if (gpuIndices !== undefined) {
this.gpuIndices = gpuIndices; this.gpuIndices = gpuIndices;
} }
if (maxTrialNumPerGpu !== undefined) {
this.maxTrialNumPerGpu = maxTrialNumPerGpu;
}
if (useActiveGpu !== undefined) {
this.useActiveGpu = useActiveGpu;
}
} }
} }
...@@ -117,13 +125,15 @@ class LocalTrainingService implements TrainingService { ...@@ -117,13 +125,15 @@ class LocalTrainingService implements TrainingService {
private rootDir!: string; private rootDir!: string;
private trialSequenceId: number; private trialSequenceId: number;
private gpuScheduler!: GPUScheduler; private gpuScheduler!: GPUScheduler;
private occupiedGpuIndices: Set<number>; private occupiedGpuIndexNumMap: Map<number, number>;
private designatedGpuIndices!: Set<number>; private designatedGpuIndices!: Set<number>;
private log: Logger; private log: Logger;
private localTrailConfig?: TrialConfig; private localTrailConfig?: TrialConfig;
private localConfig?: LocalConfig; private localConfig?: LocalConfig;
private isMultiPhase: boolean = false; private isMultiPhase: boolean;
private jobStreamMap: Map<string, ts.Stream>; private jobStreamMap: Map<string, ts.Stream>;
private maxTrialNumPerGpu: number;
private useActiveGpu: boolean;
constructor() { constructor() {
this.eventEmitter = new EventEmitter(); this.eventEmitter = new EventEmitter();
...@@ -135,7 +145,10 @@ class LocalTrainingService implements TrainingService { ...@@ -135,7 +145,10 @@ class LocalTrainingService implements TrainingService {
this.trialSequenceId = -1; this.trialSequenceId = -1;
this.jobStreamMap = new Map<string, ts.Stream>(); this.jobStreamMap = new Map<string, ts.Stream>();
this.log.info('Construct local machine training service.'); this.log.info('Construct local machine training service.');
this.occupiedGpuIndices = new Set<number>(); this.occupiedGpuIndexNumMap = new Map<number, number>();
this.maxTrialNumPerGpu = 1;
this.useActiveGpu = false;
this.isMultiPhase = false;
} }
public async run(): Promise<void> { public async run(): Promise<void> {
...@@ -304,6 +317,13 @@ class LocalTrainingService implements TrainingService { ...@@ -304,6 +317,13 @@ class LocalTrainingService implements TrainingService {
throw new Error('gpuIndices can not be empty if specified.'); throw new Error('gpuIndices can not be empty if specified.');
} }
} }
if (this.localConfig.maxTrialNumPerGpu !== undefined) {
this.maxTrialNumPerGpu = this.localConfig.maxTrialNumPerGpu;
}
if (this.localConfig.useActiveGpu !== undefined) {
this.useActiveGpu = this.localConfig.useActiveGpu;
}
break; break;
case TrialConfigMetadataKey.MULTI_PHASE: case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True'); this.isMultiPhase = (value === 'true' || value === 'True');
...@@ -356,7 +376,14 @@ class LocalTrainingService implements TrainingService { ...@@ -356,7 +376,14 @@ class LocalTrainingService implements TrainingService {
if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) { if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) {
if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') { if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') {
for (const index of trialJob.gpuIndices) { for (const index of trialJob.gpuIndices) {
this.occupiedGpuIndices.delete(index); let num: number | undefined = this.occupiedGpuIndexNumMap.get(index);
if(num === undefined) {
throw new Error(`gpu resource schedule error`);
} else if(num === 1) {
this.occupiedGpuIndexNumMap.delete(index);
} else {
this.occupiedGpuIndexNumMap.set(index, num - 1)
}
} }
} }
} }
...@@ -396,8 +423,14 @@ class LocalTrainingService implements TrainingService { ...@@ -396,8 +423,14 @@ class LocalTrainingService implements TrainingService {
return [true, resource]; return [true, resource];
} }
let selectedGPUIndices: number[] = this.gpuScheduler.getAvailableGPUIndices() let selectedGPUIndices: number[] = [];
.filter((index: number) => !this.occupiedGpuIndices.has(index)); let availableGpuIndices: number[] = this.gpuScheduler.getAvailableGPUIndices(this.useActiveGpu, this.occupiedGpuIndexNumMap);
for(let index of availableGpuIndices) {
let num: number | undefined = this.occupiedGpuIndexNumMap.get(index);
if(num === undefined || num < this.maxTrialNumPerGpu) {
selectedGPUIndices.push(index);
}
}
if (this.designatedGpuIndices !== undefined) { if (this.designatedGpuIndices !== undefined) {
this.checkSpecifiedGpuIndices(); this.checkSpecifiedGpuIndices();
...@@ -428,7 +461,12 @@ class LocalTrainingService implements TrainingService { ...@@ -428,7 +461,12 @@ class LocalTrainingService implements TrainingService {
private occupyResource(resource: {gpuIndices: number[]}): void { private occupyResource(resource: {gpuIndices: number[]}): void {
if (this.gpuScheduler !== undefined) { if (this.gpuScheduler !== undefined) {
for (const index of resource.gpuIndices) { for (const index of resource.gpuIndices) {
this.occupiedGpuIndices.add(index); let num: number | undefined = this.occupiedGpuIndexNumMap.get(index);
if(num === undefined) {
this.occupiedGpuIndexNumMap.set(index, 1)
} else {
this.occupiedGpuIndexNumMap.set(index, num + 1)
}
} }
} }
} }
......
...@@ -23,7 +23,8 @@ import * as assert from 'assert'; ...@@ -23,7 +23,8 @@ import * as assert from 'assert';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils'; import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData'; import { GPUInfo } from '../common/gpuData';
import { parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, ScheduleResultType, SSHClientManager } from './remoteMachineData'; import { RemoteMachineTrialJobDetail, parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, ScheduleResultType, SSHClientManager } from './remoteMachineData';
import { TrialJobDetail } from 'common/trainingService';
/** /**
* A simple GPU scheduler implementation * A simple GPU scheduler implementation
...@@ -45,7 +46,7 @@ export class GPUScheduler { ...@@ -45,7 +46,7 @@ export class GPUScheduler {
* Schedule a machine according to the constraints (requiredGPUNum) * Schedule a machine according to the constraints (requiredGPUNum)
* @param requiredGPUNum required GPU number * @param requiredGPUNum required GPU number
*/ */
public scheduleMachine(requiredGPUNum: number, trialJobId : string) : RemoteMachineScheduleResult { public scheduleMachine(requiredGPUNum: number, trialJobDetail : RemoteMachineTrialJobDetail) : RemoteMachineScheduleResult {
assert(requiredGPUNum >= 0); assert(requiredGPUNum >= 0);
const allRMs: RemoteMachineMeta[] = Array.from(this.machineSSHClientMap.keys()); const allRMs: RemoteMachineMeta[] = Array.from(this.machineSSHClientMap.keys());
assert(allRMs.length > 0); assert(allRMs.length > 0);
...@@ -66,7 +67,7 @@ export class GPUScheduler { ...@@ -66,7 +67,7 @@ export class GPUScheduler {
// Currenty the requireGPUNum parameter for all trial jobs are identical. // Currenty the requireGPUNum parameter for all trial jobs are identical.
if (requiredGPUNum > 0) { if (requiredGPUNum > 0) {
// Trial job requires GPU // Trial job requires GPU
const result: RemoteMachineScheduleResult | undefined = this.scheduleGPUHost(requiredGPUNum, trialJobId); const result: RemoteMachineScheduleResult | undefined = this.scheduleGPUHost(requiredGPUNum, trialJobDetail);
if (result !== undefined) { if (result !== undefined) {
return result; return result;
} }
...@@ -74,9 +75,9 @@ export class GPUScheduler { ...@@ -74,9 +75,9 @@ export class GPUScheduler {
// Trail job does not need GPU // Trail job does not need GPU
const allocatedRm: RemoteMachineMeta = this.selectMachine(allRMs); const allocatedRm: RemoteMachineMeta = this.selectMachine(allRMs);
return this.allocateHost(requiredGPUNum, allocatedRm, [], trialJobId); return this.allocateHost(requiredGPUNum, allocatedRm, [], trialJobDetail);
} }
this.log.warning(`Scheduler: trialJob id ${trialJobId}, no machine can be scheduled, return TMP_NO_AVAILABLE_GPU `); this.log.warning(`Scheduler: trialJob id ${trialJobDetail.id}, no machine can be scheduled, return TMP_NO_AVAILABLE_GPU `);
return { return {
resultType : ScheduleResultType.TMP_NO_AVAILABLE_GPU, resultType : ScheduleResultType.TMP_NO_AVAILABLE_GPU,
...@@ -87,21 +88,35 @@ export class GPUScheduler { ...@@ -87,21 +88,35 @@ export class GPUScheduler {
/** /**
* remove the job's gpu reversion * remove the job's gpu reversion
*/ */
public removeGpuReservation(trialJobId: string, rmMeta?: RemoteMachineMeta): void { public removeGpuReservation(trialJobId: string, trialJobMap: Map<string, RemoteMachineTrialJobDetail>): void {
// If remote machine has no GPU, gpuReservcation is not initialized, so check if it's undefined let trialJobDetail: RemoteMachineTrialJobDetail | undefined = trialJobMap.get(trialJobId);
if (rmMeta !== undefined && rmMeta.gpuReservation !== undefined) { if(trialJobDetail === undefined) {
rmMeta.gpuReservation.forEach((reserveTrialJobId : string, gpuIndex : number) => { throw new Error(`could not get trialJobDetail by id ${trialJobId}`);
if (reserveTrialJobId === trialJobId) { }
rmMeta.gpuReservation.delete(gpuIndex); if (trialJobDetail.rmMeta !== undefined &&
trialJobDetail.rmMeta.occupiedGpuIndexMap !== undefined &&
trialJobDetail.gpuIndices !== undefined &&
trialJobDetail.gpuIndices.length > 0) {
for (const gpuInfo of trialJobDetail.gpuIndices) {
let num: number | undefined = trialJobDetail.rmMeta.occupiedGpuIndexMap.get(gpuInfo.index);
if(num !== undefined) {
if(num === 1) {
trialJobDetail.rmMeta.occupiedGpuIndexMap.delete(gpuInfo.index);
} else {
trialJobDetail.rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num - 1)
}
} }
}); }
} }
trialJobDetail.gpuIndices = [];
trialJobMap.set(trialJobId, trialJobDetail);
} }
private scheduleGPUHost(requiredGPUNum: number, trialJobId: string): RemoteMachineScheduleResult | undefined { private scheduleGPUHost(requiredGPUNum: number, trialJobDetail: RemoteMachineTrialJobDetail): RemoteMachineScheduleResult | undefined {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = this.gpuResourceDetection(); const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = this.gpuResourceDetection();
const qualifiedRMs: RemoteMachineMeta[] = []; const qualifiedRMs: RemoteMachineMeta[] = [];
totalResourceMap.forEach((gpuInfos: GPUInfo[], rmMeta: RemoteMachineMeta) => { totalResourceMap.forEach((gpuInfos: GPUInfo[], rmMeta: RemoteMachineMeta) => {
if (gpuInfos !== undefined && gpuInfos.length >= requiredGPUNum) { if (gpuInfos !== undefined && gpuInfos.length >= requiredGPUNum) {
qualifiedRMs.push(rmMeta); qualifiedRMs.push(rmMeta);
} }
...@@ -110,7 +125,7 @@ export class GPUScheduler { ...@@ -110,7 +125,7 @@ export class GPUScheduler {
const allocatedRm: RemoteMachineMeta = this.selectMachine(qualifiedRMs); const allocatedRm: RemoteMachineMeta = this.selectMachine(qualifiedRMs);
const gpuInfos: GPUInfo[] | undefined = totalResourceMap.get(allocatedRm); const gpuInfos: GPUInfo[] | undefined = totalResourceMap.get(allocatedRm);
if (gpuInfos !== undefined) { // should always true if (gpuInfos !== undefined) { // should always true
return this.allocateHost(requiredGPUNum, allocatedRm, gpuInfos, trialJobId); return this.allocateHost(requiredGPUNum, allocatedRm, gpuInfos, trialJobDetail);
} else { } else {
assert(false, 'gpuInfos is undefined'); assert(false, 'gpuInfos is undefined');
} }
...@@ -130,9 +145,6 @@ export class GPUScheduler { ...@@ -130,9 +145,6 @@ export class GPUScheduler {
// Assgin totoal GPU count as init available GPU number // Assgin totoal GPU count as init available GPU number
if (rmMeta.gpuSummary !== undefined) { if (rmMeta.gpuSummary !== undefined) {
const availableGPUs: GPUInfo[] = []; const availableGPUs: GPUInfo[] = [];
if (rmMeta.gpuReservation === undefined) {
rmMeta.gpuReservation = new Map<number, string>();
}
const designatedGpuIndices: Set<number> | undefined = parseGpuIndices(rmMeta.gpuIndices); const designatedGpuIndices: Set<number> | undefined = parseGpuIndices(rmMeta.gpuIndices);
if (designatedGpuIndices !== undefined) { if (designatedGpuIndices !== undefined) {
for (const gpuIndex of designatedGpuIndices) { for (const gpuIndex of designatedGpuIndices) {
...@@ -145,10 +157,20 @@ export class GPUScheduler { ...@@ -145,10 +157,20 @@ export class GPUScheduler {
rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => { rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => {
// if the GPU has active process, OR be reserved by a job, // if the GPU has active process, OR be reserved by a job,
// or index not in gpuIndices configuration in machineList, // or index not in gpuIndices configuration in machineList,
// or trial number on a GPU reach max number,
// We should NOT allocate this GPU // We should NOT allocate this GPU
if (gpuInfo.activeProcessNum === 0 && !rmMeta.gpuReservation.has(gpuInfo.index) // if users set useActiveGpu, use the gpu whether there is another activeProcess
&& (designatedGpuIndices === undefined || designatedGpuIndices.has(gpuInfo.index))) { if (designatedGpuIndices === undefined || designatedGpuIndices.has(gpuInfo.index)) {
availableGPUs.push(gpuInfo); if(rmMeta.occupiedGpuIndexMap !== undefined) {
let num = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index);
let maxTrialNumPerGpu: number = rmMeta.maxTrialNumPerGpu? rmMeta.maxTrialNumPerGpu: 1;
if((num === undefined && (!rmMeta.useActiveGpu && gpuInfo.activeProcessNum === 0 || rmMeta.useActiveGpu)) ||
(num !== undefined && num < maxTrialNumPerGpu)) {
availableGPUs.push(gpuInfo);
}
} else {
throw new Error(`occupiedGpuIndexMap initialize error!`);
}
} }
}); });
totalResourceMap.set(rmMeta, availableGPUs); totalResourceMap.set(rmMeta, availableGPUs);
...@@ -170,14 +192,22 @@ export class GPUScheduler { ...@@ -170,14 +192,22 @@ export class GPUScheduler {
} }
private allocateHost(requiredGPUNum: number, rmMeta: RemoteMachineMeta, private allocateHost(requiredGPUNum: number, rmMeta: RemoteMachineMeta,
gpuInfos: GPUInfo[], trialJobId: string): RemoteMachineScheduleResult { gpuInfos: GPUInfo[], trialJobDetail: RemoteMachineTrialJobDetail): RemoteMachineScheduleResult {
assert(gpuInfos.length >= requiredGPUNum); assert(gpuInfos.length >= requiredGPUNum);
const allocatedGPUs: GPUInfo[] = this.selectGPUsForTrial(gpuInfos, requiredGPUNum); const allocatedGPUs: GPUInfo[] = this.selectGPUsForTrial(gpuInfos, requiredGPUNum);
allocatedGPUs.forEach((gpuInfo: GPUInfo) => { allocatedGPUs.forEach((gpuInfo: GPUInfo) => {
rmMeta.gpuReservation.set(gpuInfo.index, trialJobId); if(rmMeta.occupiedGpuIndexMap !== undefined) {
let num = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index);
if(num === undefined) {
num = 0;
}
rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num + 1);
}else {
throw new Error(`Machine ${rmMeta.ip} occupiedGpuIndexMap initialize error!`);
}
}); });
trialJobDetail.gpuIndices = allocatedGPUs;
trialJobDetail.rmMeta = rmMeta;
return { return {
resultType: ScheduleResultType.SUCCEED, resultType: ScheduleResultType.SUCCEED,
scheduleInfo: { scheduleInfo: {
......
...@@ -23,7 +23,7 @@ import * as fs from 'fs'; ...@@ -23,7 +23,7 @@ import * as fs from 'fs';
import { Client, ConnectConfig } from 'ssh2'; import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUSummary } from '../common/gpuData'; import { GPUSummary, GPUInfo } from '../common/gpuData';
/** /**
* Metadata of remote machine for configuration and statuc query * Metadata of remote machine for configuration and statuc query
...@@ -36,20 +36,23 @@ export class RemoteMachineMeta { ...@@ -36,20 +36,23 @@ export class RemoteMachineMeta {
public readonly sshKeyPath?: string; public readonly sshKeyPath?: string;
public readonly passphrase?: string; public readonly passphrase?: string;
public gpuSummary : GPUSummary | undefined; public gpuSummary : GPUSummary | undefined;
// GPU Reservation info, the key is GPU index, the value is the job id which reserves this GPU
public gpuReservation : Map<number, string>;
public readonly gpuIndices?: string; public readonly gpuIndices?: string;
public readonly maxTrialNumPerGpu?: number;
public occupiedGpuIndexMap: Map<number, number>;
public readonly useActiveGpu?: boolean = false;
constructor(ip : string, port : number, username : string, passwd : string, constructor(ip : string, port : number, username : string, passwd : string,
sshKeyPath: string, passphrase : string, gpuIndices?: string) { sshKeyPath: string, passphrase : string, gpuIndices?: string, maxTrialNumPerGpu?: number, useActiveGpu?: boolean) {
this.ip = ip; this.ip = ip;
this.port = port; this.port = port;
this.username = username; this.username = username;
this.passwd = passwd; this.passwd = passwd;
this.sshKeyPath = sshKeyPath; this.sshKeyPath = sshKeyPath;
this.passphrase = passphrase; this.passphrase = passphrase;
this.gpuReservation = new Map<number, string>();
this.gpuIndices = gpuIndices; this.gpuIndices = gpuIndices;
this.maxTrialNumPerGpu = maxTrialNumPerGpu;
this.occupiedGpuIndexMap = new Map<number, number>();
this.useActiveGpu = useActiveGpu;
} }
} }
...@@ -97,6 +100,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -97,6 +100,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public sequenceId: number; public sequenceId: number;
public rmMeta?: RemoteMachineMeta; public rmMeta?: RemoteMachineMeta;
public isEarlyStopped?: boolean; public isEarlyStopped?: boolean;
public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) { workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
...@@ -107,6 +111,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -107,6 +111,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
this.form = form; this.form = form;
this.sequenceId = sequenceId; this.sequenceId = sequenceId;
this.tags = []; this.tags = [];
this.gpuIndices = []
} }
} }
......
...@@ -36,7 +36,7 @@ import { ObservableTimer } from '../../common/observableTimer'; ...@@ -36,7 +36,7 @@ import { ObservableTimer } from '../../common/observableTimer';
import { import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, NNIManagerIpConfig HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, NNIManagerIpConfig
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString, getJobCancelStatus, getRemoteTmpDir,getIPV4Address } from '../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, uniqueString, getJobCancelStatus, getRemoteTmpDir,getIPV4Address, getVersion, unixPathJoin } from '../../common/utils';
import { GPUSummary } from '../common/gpuData'; import { GPUSummary } from '../common/gpuData';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
...@@ -48,10 +48,9 @@ import { ...@@ -48,10 +48,9 @@ import {
} from './remoteMachineData'; } from './remoteMachineData';
import { GPU_INFO_COLLECTOR_FORMAT_LINUX } from '../common/gpuData'; import { GPU_INFO_COLLECTOR_FORMAT_LINUX } from '../common/gpuData';
import { SSHClientUtility } from './sshClientUtility'; import { SSHClientUtility } from './sshClientUtility';
import { validateCodeDir } from '../common/util'; import { validateCodeDir, execRemove, execMkdir, execCopydir } from '../common/util';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer'; import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { mkDirP, getVersion } from '../../common/utils';
/** /**
* Training Service implementation for Remote Machine (Linux) * Training Service implementation for Remote Machine (Linux)
...@@ -234,7 +233,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -234,7 +233,7 @@ class RemoteMachineTrainingService implements TrainingService {
} else if (form.jobType === 'TRIAL') { } else if (form.jobType === 'TRIAL') {
// Generate trial job id(random) // Generate trial job id(random)
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId); const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail( const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
trialJobId, trialJobId,
...@@ -283,7 +282,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -283,7 +282,7 @@ class RemoteMachineTrainingService implements TrainingService {
private updateGpuReservation() { private updateGpuReservation() {
for (const [key, value] of this.trialJobsMap) { for (const [key, value] of this.trialJobsMap) {
if(!['WAITING', 'RUNNING'].includes(value.status)) { if(!['WAITING', 'RUNNING'].includes(value.status)) {
this.gpuScheduler.removeGpuReservation(value.id, value.rmMeta); this.gpuScheduler.removeGpuReservation(key, this.trialJobsMap);
} }
}; };
} }
...@@ -354,7 +353,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -354,7 +353,7 @@ class RemoteMachineTrainingService implements TrainingService {
case TrialConfigMetadataKey.MACHINE_LIST: case TrialConfigMetadataKey.MACHINE_LIST:
await this.setupConnections(value); await this.setupConnections(value);
//remove local temp files //remove local temp files
await cpp.exec(`rm -rf ${this.getLocalGpuMetricCollectorDir()}`); await execRemove(this.getLocalGpuMetricCollectorDir());
break; break;
case TrialConfigMetadataKey.TRIAL_CONFIG: case TrialConfigMetadataKey.TRIAL_CONFIG:
const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value); const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
...@@ -417,7 +416,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -417,7 +416,7 @@ class RemoteMachineTrainingService implements TrainingService {
private async cleanupConnections(): Promise<void> { private async cleanupConnections(): Promise<void> {
try{ try{
for (const [rmMeta, sshClientManager] of this.machineSSHClientMap.entries()) { for (const [rmMeta, sshClientManager] of this.machineSSHClientMap.entries()) {
let jobpidPath: string = path.join(this.getRemoteScriptsPath(rmMeta.username), 'pid'); let jobpidPath: string = unixPathJoin(this.getRemoteScriptsPath(rmMeta.username), 'pid');
let client: Client | undefined = sshClientManager.getFirstSSHClient(); let client: Client | undefined = sshClientManager.getFirstSSHClient();
if(client) { if(client) {
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, client); await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, client);
...@@ -438,7 +437,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -438,7 +437,7 @@ class RemoteMachineTrainingService implements TrainingService {
*/ */
private getLocalGpuMetricCollectorDir(): string { private getLocalGpuMetricCollectorDir(): string {
let userName: string = path.basename(os.homedir()); //get current user name of os let userName: string = path.basename(os.homedir()); //get current user name of os
return `${os.tmpdir()}/${userName}/nni/scripts/`; return path.join(os.tmpdir(), userName, 'nni', 'scripts');
} }
/** /**
...@@ -447,14 +446,14 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -447,14 +446,14 @@ class RemoteMachineTrainingService implements TrainingService {
*/ */
private async generateGpuMetricsCollectorScript(userName: string): Promise<void> { private async generateGpuMetricsCollectorScript(userName: string): Promise<void> {
let gpuMetricCollectorScriptFolder : string = this.getLocalGpuMetricCollectorDir(); let gpuMetricCollectorScriptFolder : string = this.getLocalGpuMetricCollectorDir();
await cpp.exec(`mkdir -p ${path.join(gpuMetricCollectorScriptFolder, userName)}`); await execMkdir(path.join(gpuMetricCollectorScriptFolder, userName));
//generate gpu_metrics_collector.sh //generate gpu_metrics_collector.sh
let gpuMetricsCollectorScriptPath: string = path.join(gpuMetricCollectorScriptFolder, userName, 'gpu_metrics_collector.sh'); let gpuMetricsCollectorScriptPath: string = path.join(gpuMetricCollectorScriptFolder, userName, 'gpu_metrics_collector.sh');
const remoteGPUScriptsDir: string = this.getRemoteScriptsPath(userName); // This directory is used to store gpu_metrics and pid created by script const remoteGPUScriptsDir: string = this.getRemoteScriptsPath(userName); // This directory is used to store gpu_metrics and pid created by script
const gpuMetricsCollectorScriptContent: string = String.Format( const gpuMetricsCollectorScriptContent: string = String.Format(
GPU_INFO_COLLECTOR_FORMAT_LINUX, GPU_INFO_COLLECTOR_FORMAT_LINUX,
remoteGPUScriptsDir, remoteGPUScriptsDir,
path.join(remoteGPUScriptsDir, 'pid'), unixPathJoin(remoteGPUScriptsDir, 'pid'),
); );
await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' });
} }
...@@ -481,7 +480,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -481,7 +480,7 @@ class RemoteMachineTrainingService implements TrainingService {
private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> { private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> {
// Create root working directory after ssh connection is ready // Create root working directory after ssh connection is ready
await this.generateGpuMetricsCollectorScript(rmMeta.username); //generate gpu script in local machine first, will copy to remote machine later await this.generateGpuMetricsCollectorScript(rmMeta.username); //generate gpu script in local machine first, will copy to remote machine later
const nniRootDir: string = `${os.tmpdir()}/nni`; const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni');
await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn); await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn);
// Copy NNI scripts to remote expeirment working directory // Copy NNI scripts to remote expeirment working directory
...@@ -490,15 +489,15 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -490,15 +489,15 @@ class RemoteMachineTrainingService implements TrainingService {
await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteGpuScriptCollectorDir}`, conn); await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteGpuScriptCollectorDir}`, conn);
await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn); await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn);
//copy gpu_metrics_collector.sh to remote //copy gpu_metrics_collector.sh to remote
await SSHClientUtility.copyFileToRemote(path.join(localGpuScriptCollectorDir, rmMeta.username, 'gpu_metrics_collector.sh'), path.join(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh'), conn); await SSHClientUtility.copyFileToRemote(path.join(localGpuScriptCollectorDir, rmMeta.username, 'gpu_metrics_collector.sh'), unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh'), conn);
//Begin to execute gpu_metrics_collection scripts //Begin to execute gpu_metrics_collection scripts
SSHClientUtility.remoteExeCommand(`bash ${path.join(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh')}`, conn); SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh')}`, conn);
this.timer.subscribe( this.timer.subscribe(
async (tick: number) => { async (tick: number) => {
const cmdresult: RemoteCommandResult = await SSHClientUtility.remoteExeCommand( const cmdresult: RemoteCommandResult = await SSHClientUtility.remoteExeCommand(
`tail -n 1 ${path.join(remoteGpuScriptCollectorDir, 'gpu_metrics')}`, conn); `tail -n 1 ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics')}`, conn);
if (cmdresult && cmdresult.stdout) { if (cmdresult && cmdresult.stdout) {
rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout); rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout);
} }
...@@ -522,7 +521,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -522,7 +521,7 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
// get an ssh client from scheduler // get an ssh client from scheduler
const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobId); const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobDetail);
if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) { if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
const errorMessage : string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`; const errorMessage : string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
this.log.error(errorMessage); this.log.error(errorMessage);
...@@ -531,7 +530,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -531,7 +530,7 @@ class RemoteMachineTrainingService implements TrainingService {
} else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED } else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED
&& rmScheduleResult.scheduleInfo !== undefined) { && rmScheduleResult.scheduleInfo !== undefined) {
const rmScheduleInfo : RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo; const rmScheduleInfo : RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId); const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta; trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
...@@ -543,6 +542,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -543,6 +542,7 @@ class RemoteMachineTrainingService implements TrainingService {
trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`; trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`;
trialJobDetail.startTime = Date.now(); trialJobDetail.startTime = Date.now();
this.trialJobsMap.set(trialJobId, trialJobDetail);
deferred.resolve(true); deferred.resolve(true);
} else if (rmScheduleResult.resultType === ScheduleResultType.TMP_NO_AVAILABLE_GPU) { } else if (rmScheduleResult.resultType === ScheduleResultType.TMP_NO_AVAILABLE_GPU) {
this.log.info(`Right now no available GPU can be allocated for trial ${trialJobId}, will try to schedule later`); this.log.info(`Right now no available GPU can be allocated for trial ${trialJobId}, will try to schedule later`);
...@@ -575,7 +575,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -575,7 +575,7 @@ class RemoteMachineTrainingService implements TrainingService {
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${trialWorkingFolder}`, sshClient); await SSHClientUtility.remoteExeCommand(`mkdir -p ${trialWorkingFolder}`, sshClient);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${path.join(trialWorkingFolder, '.nni')}`, sshClient); await SSHClientUtility.remoteExeCommand(`mkdir -p ${unixPathJoin(trialWorkingFolder, '.nni')}`, sshClient);
// RemoteMachineRunShellFormat is the run shell format string, // RemoteMachineRunShellFormat is the run shell format string,
// See definition in remoteMachineData.ts // See definition in remoteMachineData.ts
...@@ -603,20 +603,20 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -603,20 +603,20 @@ class RemoteMachineTrainingService implements TrainingService {
getExperimentId(), getExperimentId(),
trialJobDetail.sequenceId.toString(), trialJobDetail.sequenceId.toString(),
this.isMultiPhase, this.isMultiPhase,
path.join(trialWorkingFolder, '.nni', 'jobpid'), unixPathJoin(trialWorkingFolder, '.nni', 'jobpid'),
command, command,
nniManagerIp, nniManagerIp,
this.remoteRestServerPort, this.remoteRestServerPort,
version, version,
this.logCollection, this.logCollection,
path.join(trialWorkingFolder, '.nni', 'code') unixPathJoin(trialWorkingFolder, '.nni', 'code')
) )
//create tmp trial working folder locally. //create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${path.join(trialLocalTempFolder, '.nni')}`); await execMkdir(path.join(trialLocalTempFolder, '.nni'));
//create tmp trial working folder locally. //create tmp trial working folder locally.
await cpp.exec(`cp -r ${this.trialConfig.codeDir}/* ${trialLocalTempFolder}`); await execCopydir(path.join(this.trialConfig.codeDir, '*'), trialLocalTempFolder);
const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT; const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files // Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' });
...@@ -626,7 +626,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -626,7 +626,7 @@ class RemoteMachineTrainingService implements TrainingService {
// Copy files in codeDir to remote working directory // Copy files in codeDir to remote working directory
await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS); await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS);
// Execute command in remote machine // Execute command in remote machine
SSHClientUtility.remoteExeCommand(`bash ${path.join(trialWorkingFolder, 'run.sh')}`, sshClient); SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
} }
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> { private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
...@@ -646,8 +646,8 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -646,8 +646,8 @@ class RemoteMachineTrainingService implements TrainingService {
); );
await fs.promises.writeFile(path.join(localDir, 'run.sh'), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(localDir, 'run.sh'), runScriptContent, { encoding: 'utf8' });
await SSHClientUtility.copyFileToRemote( await SSHClientUtility.copyFileToRemote(
path.join(localDir, 'run.sh'), path.join(remoteDir, 'run.sh'), sshClient); path.join(localDir, 'run.sh'), unixPathJoin(remoteDir, 'run.sh'), sshClient);
SSHClientUtility.remoteExeCommand(`bash ${path.join(remoteDir, 'run.sh')}`, sshClient); SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteDir, 'run.sh')}`, sshClient);
const jobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail( const jobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
jobId, 'RUNNING', Date.now(), remoteDir, form, this.generateSequenceId() jobId, 'RUNNING', Date.now(), remoteDir, form, this.generateSequenceId()
...@@ -672,7 +672,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -672,7 +672,7 @@ class RemoteMachineTrainingService implements TrainingService {
private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, sshClient: Client): Promise<TrialJobDetail> { private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, sshClient: Client): Promise<TrialJobDetail> {
const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>(); const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
const jobpidPath: string = this.getJobPidPath(trialJob.id); const jobpidPath: string = this.getJobPidPath(trialJob.id);
const trialReturnCodeFilePath: string = path.join(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code'); const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code');
try { try {
const killResult: number = (await SSHClientUtility.remoteExeCommand(`kill -0 \`cat ${jobpidPath}\``, sshClient)).exitCode; const killResult: number = (await SSHClientUtility.remoteExeCommand(`kill -0 \`cat ${jobpidPath}\``, sshClient)).exitCode;
// if the process of jobpid is not alive any more // if the process of jobpid is not alive any more
...@@ -712,15 +712,15 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -712,15 +712,15 @@ class RemoteMachineTrainingService implements TrainingService {
} }
private getRemoteScriptsPath(userName: string): string { private getRemoteScriptsPath(userName: string): string {
return path.join(getRemoteTmpDir(this.remoteOS), userName, 'nni', 'scripts'); return unixPathJoin(getRemoteTmpDir(this.remoteOS), userName, 'nni', 'scripts');
} }
private getHostJobRemoteDir(jobId: string): string { private getHostJobRemoteDir(jobId: string): string {
return path.join(this.remoteExpRootDir, 'hostjobs', jobId); return unixPathJoin(this.remoteExpRootDir, 'hostjobs', jobId);
} }
private getRemoteExperimentRootDir(): string{ private getRemoteExperimentRootDir(): string{
return path.join(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId()); return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId());
} }
public get MetricsEmitter() : EventEmitter { public get MetricsEmitter() : EventEmitter {
...@@ -735,9 +735,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -735,9 +735,9 @@ class RemoteMachineTrainingService implements TrainingService {
let jobpidPath: string; let jobpidPath: string;
if (trialJobDetail.form.jobType === 'TRIAL') { if (trialJobDetail.form.jobType === 'TRIAL') {
jobpidPath = path.join(trialJobDetail.workingDirectory, '.nni', 'jobpid'); jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
} else if (trialJobDetail.form.jobType === 'HOST') { } else if (trialJobDetail.form.jobType === 'HOST') {
jobpidPath = path.join(this.getHostJobRemoteDir(jobId), 'jobpid'); jobpidPath = unixPathJoin(this.getHostJobRemoteDir(jobId), 'jobpid');
} else { } else {
throw new Error(`Job type not supported: ${trialJobDetail.form.jobType}`); throw new Error(`Job type not supported: ${trialJobDetail.form.jobType}`);
} }
...@@ -751,14 +751,14 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -751,14 +751,14 @@ class RemoteMachineTrainingService implements TrainingService {
throw new Error('sshClient is undefined.'); throw new Error('sshClient is undefined.');
} }
const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId); const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
const fileName: string = generateParamFileName(hyperParameters); const fileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, fileName); const localFilepath: string = path.join(trialLocalTempFolder, fileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
await SSHClientUtility.copyFileToRemote(localFilepath, path.join(trialWorkingFolder, fileName), sshClient); await SSHClientUtility.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName), sshClient);
} }
private generateSequenceId(): number { private generateSequenceId(): number {
......
...@@ -28,8 +28,9 @@ import * as stream from 'stream'; ...@@ -28,8 +28,9 @@ import * as stream from 'stream';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { uniqueString, getRemoteTmpDir } from '../../common/utils'; import { uniqueString, getRemoteTmpDir, unixPathJoin } from '../../common/utils';
import { RemoteCommandResult } from './remoteMachineData'; import { RemoteCommandResult } from './remoteMachineData';
import { execRemove, tarAdd } from '../common/util';
/** /**
* *
...@@ -47,13 +48,13 @@ export namespace SSHClientUtility { ...@@ -47,13 +48,13 @@ export namespace SSHClientUtility {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
const tmpTarName: string = `${uniqueString(10)}.tar.gz`; const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
const localTarPath: string = path.join(os.tmpdir(), tmpTarName); const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
const remoteTarPath: string = path.join(getRemoteTmpDir(remoteOS), tmpTarName); const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
// Compress files in local directory to experiment root directory // Compress files in local directory to experiment root directory
await cpp.exec(`tar -czf ${localTarPath} -C ${localDirectory} .`); await tarAdd(localTarPath, localDirectory);
// Copy the compressed file to remoteDirectory and delete it // Copy the compressed file to remoteDirectory and delete it
await copyFileToRemote(localTarPath, remoteTarPath, sshClient); await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
await cpp.exec(`rm ${localTarPath}`); await execRemove(localTarPath);
// Decompress the remote compressed file in and delete it // Decompress the remote compressed file in and delete it
await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient); await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
await remoteExeCommand(`rm ${remoteTarPath}`, sshClient); await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
......
...@@ -22,11 +22,7 @@ batch_tuner.py including: ...@@ -22,11 +22,7 @@ batch_tuner.py including:
class BatchTuner class BatchTuner
""" """
import copy import logging
from enum import Enum, unique
import random
import numpy as np
import nni import nni
from nni.tuner import Tuner from nni.tuner import Tuner
...@@ -35,6 +31,7 @@ TYPE = '_type' ...@@ -35,6 +31,7 @@ TYPE = '_type'
CHOICE = 'choice' CHOICE = 'choice'
VALUE = '_value' VALUE = '_value'
logger = logging.getLogger('batch_tuner_AutoML')
class BatchTuner(Tuner): class BatchTuner(Tuner):
""" """
...@@ -46,7 +43,7 @@ class BatchTuner(Tuner): ...@@ -46,7 +43,7 @@ class BatchTuner(Tuner):
} }
} }
""" """
def __init__(self): def __init__(self):
self.count = -1 self.count = -1
self.values = [] self.values = []
...@@ -54,14 +51,14 @@ class BatchTuner(Tuner): ...@@ -54,14 +51,14 @@ class BatchTuner(Tuner):
def is_valid(self, search_space): def is_valid(self, search_space):
""" """
Check the search space is valid: only contains 'choice' type Check the search space is valid: only contains 'choice' type
Parameters Parameters
---------- ----------
search_space : dict search_space : dict
""" """
if not len(search_space) == 1: if not len(search_space) == 1:
raise RuntimeError('BatchTuner only supprt one combined-paramreters key.') raise RuntimeError('BatchTuner only supprt one combined-paramreters key.')
for param in search_space: for param in search_space:
param_type = search_space[param][TYPE] param_type = search_space[param][TYPE]
if not param_type == CHOICE: if not param_type == CHOICE:
...@@ -73,8 +70,8 @@ class BatchTuner(Tuner): ...@@ -73,8 +70,8 @@ class BatchTuner(Tuner):
return None return None
def update_search_space(self, search_space): def update_search_space(self, search_space):
"""Update the search space """Update the search space
Parameters Parameters
---------- ----------
search_space : dict search_space : dict
...@@ -88,8 +85,8 @@ class BatchTuner(Tuner): ...@@ -88,8 +85,8 @@ class BatchTuner(Tuner):
---------- ----------
parameter_id : int parameter_id : int
""" """
self.count +=1 self.count += 1
if self.count>len(self.values)-1: if self.count > len(self.values) - 1:
raise nni.NoMoreTrialError('no more parameters now.') raise nni.NoMoreTrialError('no more parameters now.')
return self.values[self.count] return self.values[self.count]
...@@ -97,4 +94,31 @@ class BatchTuner(Tuner): ...@@ -97,4 +94,31 @@ class BatchTuner(Tuner):
pass pass
def import_data(self, data): def import_data(self, data):
pass """Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
if len(self.values) == 0:
logger.info("Search space has not been initialized, skip this data import")
return
self.values = self.values[(self.count+1):]
self.count = -1
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
# simply validate data format
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue
_completed_num += 1
if _params in self.values:
self.values.remove(_params)
logger.info("Successfully import data to batch tuner, total data: %d, imported data: %d.", len(data), _completed_num)
...@@ -31,7 +31,7 @@ import ConfigSpace.hyperparameters as CSH ...@@ -31,7 +31,7 @@ import ConfigSpace.hyperparameters as CSH
from nni.protocol import CommandType, send from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.utils import OptimizeMode, extract_scalar_reward from nni.utils import OptimizeMode, extract_scalar_reward, randint_to_quniform
from .config_generator import CG_BOHB from .config_generator import CG_BOHB
...@@ -443,6 +443,7 @@ class BOHB(MsgDispatcherBase): ...@@ -443,6 +443,7 @@ class BOHB(MsgDispatcherBase):
search space of this experiment search space of this experiment
""" """
search_space = data search_space = data
randint_to_quniform(search_space)
cs = CS.ConfigurationSpace() cs = CS.ConfigurationSpace()
for var in search_space: for var in search_space:
_type = str(search_space[var]["_type"]) _type = str(search_space[var]["_type"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment