Unverified Commit cf3d434f authored by fishyds's avatar fishyds Committed by GitHub
Browse files

Add codeDir file count validation for setClusterConfig (#409)

* Add codeDir file count validation for setClusterConfig

* fix a small bug if find command is not installed

* Remove codeDir validation for local training service

* Remove useless import
parent d2f597a6
......@@ -21,6 +21,7 @@
import * as assert from 'assert';
import { randomBytes } from 'crypto';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
......@@ -32,6 +33,7 @@ import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentId, setExperimentStartupInfo } from './experimentStartupInfo';
import { Manager } from './manager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
import { getLogger } from './log';
function getExperimentRootDir(): string {
return path.join(os.homedir(), 'nni', 'experiments', getExperimentId());
......@@ -287,5 +289,38 @@ function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus {
return isEarlyStopped ? 'EARLY_STOPPED' : 'USER_CANCELED';
}
export {getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getLogDir, getExperimentRootDir, getJobCancelStatus,
getDefaultDatabaseDir, getIPV4Address, mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect };
/**
 * Count the regular files under a directory, recursively, by running
 * `find <directory> -type f | wc -l` in a shell.
 *
 * @param directory directory to scan; it must exist
 * @param timeoutMilliSeconds how long to wait for the count before giving up;
 *        defaults to 5000 (5 seconds) when omitted
 * @returns a promise resolved with the number of files found, or with -1 when the
 *          command output cannot be parsed as a non-zero integer (so an output of
 *          `0` also yields -1). The promise is rejected when the command itself
 *          fails or when the timeout elapses first.
 * @throws Error synchronously when `directory` does not exist
 */
function countFilesRecursively(directory: string, timeoutMilliSeconds?: number): Promise<number> {
    if (!fs.existsSync(directory)) {
        throw Error(`Direcotory ${directory} doesn't exist`);
    }

    // Honor the caller-supplied timeout; fall back to the historical 5 seconds.
    const timeoutLimit: number = timeoutMilliSeconds !== undefined ? timeoutMilliSeconds : 5000;

    let timeoutId: ReturnType<typeof setTimeout> | undefined;
    const delayTimeout: Promise<number> = new Promise<number>((_resolve, reject): void => {
        // Reject once the time limit is reached so that a huge directory tree cannot block the caller
        timeoutId = setTimeout(() => {
            reject(new Error(`Timeout: path ${directory} has too many files`));
        }, timeoutLimit);
    });

    // Single-quote the path (escaping embedded single quotes) so that spaces and
    // shell metacharacters in the directory name are passed to `find` literally.
    const quotedDirectory: string = "'" + directory.replace(/'/g, "'\\''") + "'";

    // Returning the exec promise chain (instead of resolving a separate deferred) lets
    // a failure of the command reject the result immediately rather than being lost.
    const counting: Promise<number> = cpp.exec(`find ${quotedDirectory} -type f | wc -l`).then((result) => {
        const parsed: number = result.stdout ? parseInt(result.stdout, 10) : NaN;

        // -1 signals "count unknown": empty/unparsable output, or an output of 0
        return Number.isNaN(parsed) || parsed === 0 ? -1 : parsed;
    });

    return Promise.race([counting, delayTimeout]).finally(() => {
        if (timeoutId !== undefined) {
            clearTimeout(timeoutId);
        }
    });
}
export {countFilesRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address,
mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect };
import { getLogger } from "common/log";
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import { countFilesRecursively } from '../../common/utils'
/**
 * Validate codeDir: count the files under codeDir recursively and fail when there are too many.
 *
 * @param codeDir codeDir in nni config file
 * @returns file number under codeDir, as reported by countFilesRecursively
 * @throws Error when counting the files fails, or when more than 1000 files are found
 */
export async function validateCodeDir(codeDir: string): Promise<number> {
    // Upper bound on the number of files accepted in a code directory
    const maxFileCount: number = 1000;

    let fileCount: number;
    try {
        fileCount = await countFilesRecursively(codeDir);
    } catch (error) {
        // Wrap the failure so the caller can tell it came from the counting step
        throw new Error(`Call count file error: ${error}`);
    }

    if (fileCount > maxFileCount) {
        const errMessage: string = `Too many files(${fileCount} found) in ${codeDir},`
            + ` please check if it's a valid code dir`;
        throw new Error(errMessage);
    }

    return fileCount;
}
\ No newline at end of file
......@@ -40,6 +40,7 @@ import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, NFSCon
import { KubeflowTrialJobDetail } from './kubeflowData';
import { KubeflowJobRestServer } from './kubeflowJobRestServer';
import { KubeflowJobInfoCollector } from './kubeflowJobInfoCollector';
import { validateCodeDir } from '../common/util';
import { AzureStorageClientUtility } from './azureStorageClientUtils';
import * as azureStorage from 'azure-storage';
......@@ -360,6 +361,15 @@ class KubeflowTrainingService implements TrainingService {
this.kubeflowTrialConfig = <KubeflowTrialConfig>JSON.parse(value);
assert(this.kubeflowClusterConfig !== undefined && this.kubeflowTrialConfig.worker !== undefined);
// Validate to make sure codeDir doesn't have too many files
try {
await validateCodeDir(this.kubeflowTrialConfig.codeDir);
} catch(error) {
this.log.error(error);
return Promise.reject(new Error(error));
}
break;
default:
break;
......
......@@ -38,12 +38,14 @@ import {
JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, NNIManagerIpConfig
} from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { countFilesRecursively, delay, generateParamFileName,
getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { PAIJobRestServer } from './paiJobRestServer'
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_LOG_PATH_FORMAT } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { String } from 'typescript-string-operations';
import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
import { validateCodeDir } from '../common/util';
var WebHDFS = require('webhdfs');
......@@ -395,6 +397,15 @@ class PAITrainingService implements TrainingService {
).replace(/\r\n|\n|\r/gm, '');
}
// Validate to make sure codeDir doesn't have too many files
try {
await validateCodeDir(this.paiTrialConfig.codeDir);
} catch(error) {
this.log.error(error);
deferred.reject(new Error(error));
break;
}
const hdfsDirContent = this.paiTrialConfig.outputDir.match(this.hdfsDirPattern);
if(hdfsDirContent === null) {
......
......@@ -48,6 +48,7 @@ import {
RemoteMachineTrialJobDetail, ScheduleResultType
} from './remoteMachineData';
import { SSHClientUtility } from './sshClientUtility';
import { validateCodeDir} from '../common/util';
/**
* Training Service implementation for Remote Machine (Linux)
......@@ -297,6 +298,15 @@ class RemoteMachineTrainingService implements TrainingService {
if (!fs.lstatSync(remoteMachineTrailConfig.codeDir).isDirectory()) {
throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`);
}
// Validate to make sure codeDir doesn't have too many files
try {
await validateCodeDir(remoteMachineTrailConfig.codeDir);
} catch(error) {
this.log.error(error);
return Promise.reject(new Error(error));
}
this.trialConfig = remoteMachineTrailConfig;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment