Unverified Commit cf3d434f authored by fishyds's avatar fishyds Committed by GitHub
Browse files

Add codeDir file count validation for setClusterConfig (#409)

* Add codeDir file count validation for setClusterConfig

* fix a small bug if find command is not installed

* Remove codeDir validation for local training service

* Remove useless import
parent d2f597a6
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
import * as assert from 'assert'; import * as assert from 'assert';
import { randomBytes } from 'crypto'; import { randomBytes } from 'crypto';
import * as cpp from 'child-process-promise';
import * as fs from 'fs'; import * as fs from 'fs';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
...@@ -32,6 +33,7 @@ import { Database, DataStore } from './datastore'; ...@@ -32,6 +33,7 @@ import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentId, setExperimentStartupInfo } from './experimentStartupInfo'; import { ExperimentStartupInfo, getExperimentId, setExperimentStartupInfo } from './experimentStartupInfo';
import { Manager } from './manager'; import { Manager } from './manager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
import { getLogger } from './log';
function getExperimentRootDir(): string { function getExperimentRootDir(): string {
return path.join(os.homedir(), 'nni', 'experiments', getExperimentId()); return path.join(os.homedir(), 'nni', 'experiments', getExperimentId());
...@@ -287,5 +289,38 @@ function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus { ...@@ -287,5 +289,38 @@ function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus {
return isEarlyStopped ? 'EARLY_STOPPED' : 'USER_CANCELED'; return isEarlyStopped ? 'EARLY_STOPPED' : 'USER_CANCELED';
} }
export {getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getLogDir, getExperimentRootDir, getJobCancelStatus, /**
getDefaultDatabaseDir, getIPV4Address, mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect }; * Utility method to calculate file numbers under a directory, recursively
* @param directory directory name
*/
function countFilesRecursively(directory: string, timeoutMilliSeconds?: number): Promise<number> {
if(!fs.existsSync(directory)) {
throw Error(`Direcotory ${directory} doesn't exist`);
}
const deferred: Deferred<number> = new Deferred<number>();
let timeoutId : NodeJS.Timer
const delayTimeout : Promise<number> = new Promise((resolve : Function, reject : Function) : void => {
// Set timeout and reject the promise once reach timeout (5 seconds)
timeoutId = setTimeout(() => {
reject(new Error(`Timeout: path ${directory} has too many files`));
}, 5000);
});
let fileCount: number = -1;
cpp.exec(`find ${directory} -type f | wc -l`).then((result) => {
if(result.stdout && parseInt(result.stdout)) {
fileCount = parseInt(result.stdout);
}
deferred.resolve(fileCount);
});
return Promise.race([deferred.promise, delayTimeout]).finally(() => {
clearTimeout(timeoutId);
});
}
export {countFilesRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address,
mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect };
import { getLogger } from "common/log";
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import { countFilesRecursively } from '../../common/utils'
/**
* Validate codeDir, calculate file count recursively under codeDir, and throw error if any rule is broken
*
* @param codeDir codeDir in nni config file
* @returns file number under codeDir
*/
export async function validateCodeDir(codeDir: string) : Promise<number> {
let fileCount: number | undefined;
try {
fileCount = await countFilesRecursively(codeDir);
} catch(error) {
throw new Error(`Call count file error: ${error}`);
}
if(fileCount && fileCount > 1000) {
const errMessage: string = `Too many files(${fileCount} found}) in ${codeDir},`
+ ` please check if it's a valid code dir`;
throw new Error(errMessage);
}
return fileCount;
}
\ No newline at end of file
...@@ -40,6 +40,7 @@ import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, NFSCon ...@@ -40,6 +40,7 @@ import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, NFSCon
import { KubeflowTrialJobDetail } from './kubeflowData'; import { KubeflowTrialJobDetail } from './kubeflowData';
import { KubeflowJobRestServer } from './kubeflowJobRestServer'; import { KubeflowJobRestServer } from './kubeflowJobRestServer';
import { KubeflowJobInfoCollector } from './kubeflowJobInfoCollector'; import { KubeflowJobInfoCollector } from './kubeflowJobInfoCollector';
import { validateCodeDir } from '../common/util';
import { AzureStorageClientUtility } from './azureStorageClientUtils'; import { AzureStorageClientUtility } from './azureStorageClientUtils';
import * as azureStorage from 'azure-storage'; import * as azureStorage from 'azure-storage';
...@@ -360,6 +361,15 @@ class KubeflowTrainingService implements TrainingService { ...@@ -360,6 +361,15 @@ class KubeflowTrainingService implements TrainingService {
this.kubeflowTrialConfig = <KubeflowTrialConfig>JSON.parse(value); this.kubeflowTrialConfig = <KubeflowTrialConfig>JSON.parse(value);
assert(this.kubeflowClusterConfig !== undefined && this.kubeflowTrialConfig.worker !== undefined); assert(this.kubeflowClusterConfig !== undefined && this.kubeflowTrialConfig.worker !== undefined);
// Validate to make sure codeDir doesn't have too many files
try {
await validateCodeDir(this.kubeflowTrialConfig.codeDir);
} catch(error) {
this.log.error(error);
return Promise.reject(new Error(error));
}
break; break;
default: default:
break; break;
......
...@@ -38,12 +38,14 @@ import { ...@@ -38,12 +38,14 @@ import {
JobApplicationForm, TrainingService, TrialJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, NNIManagerIpConfig TrialJobDetail, TrialJobMetric, NNIManagerIpConfig
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils'; import { countFilesRecursively, delay, generateParamFileName,
getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { PAIJobRestServer } from './paiJobRestServer' import { PAIJobRestServer } from './paiJobRestServer'
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_LOG_PATH_FORMAT } from './paiData'; import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_LOG_PATH_FORMAT } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector'; import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig'; import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
import { validateCodeDir } from '../common/util';
var WebHDFS = require('webhdfs'); var WebHDFS = require('webhdfs');
...@@ -395,6 +397,15 @@ class PAITrainingService implements TrainingService { ...@@ -395,6 +397,15 @@ class PAITrainingService implements TrainingService {
).replace(/\r\n|\n|\r/gm, ''); ).replace(/\r\n|\n|\r/gm, '');
} }
// Validate to make sure codeDir doesn't have too many files
try {
await validateCodeDir(this.paiTrialConfig.codeDir);
} catch(error) {
this.log.error(error);
deferred.reject(new Error(error));
break;
}
const hdfsDirContent = this.paiTrialConfig.outputDir.match(this.hdfsDirPattern); const hdfsDirContent = this.paiTrialConfig.outputDir.match(this.hdfsDirPattern);
if(hdfsDirContent === null) { if(hdfsDirContent === null) {
......
...@@ -48,6 +48,7 @@ import { ...@@ -48,6 +48,7 @@ import {
RemoteMachineTrialJobDetail, ScheduleResultType RemoteMachineTrialJobDetail, ScheduleResultType
} from './remoteMachineData'; } from './remoteMachineData';
import { SSHClientUtility } from './sshClientUtility'; import { SSHClientUtility } from './sshClientUtility';
import { validateCodeDir} from '../common/util';
/** /**
* Training Service implementation for Remote Machine (Linux) * Training Service implementation for Remote Machine (Linux)
...@@ -297,6 +298,15 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -297,6 +298,15 @@ class RemoteMachineTrainingService implements TrainingService {
if (!fs.lstatSync(remoteMachineTrailConfig.codeDir).isDirectory()) { if (!fs.lstatSync(remoteMachineTrailConfig.codeDir).isDirectory()) {
throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`); throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`);
} }
// Validate to make sure codeDir doesn't have too many files
try {
await validateCodeDir(remoteMachineTrailConfig.codeDir);
} catch(error) {
this.log.error(error);
return Promise.reject(new Error(error));
}
this.trialConfig = remoteMachineTrailConfig; this.trialConfig = remoteMachineTrailConfig;
break; break;
case TrialConfigMetadataKey.MULTI_PHASE: case TrialConfigMetadataKey.MULTI_PHASE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment