Unverified Commit 9f4485c1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #130 from Microsoft/master

merge master
parents 1ee97350 51fbf695
......@@ -71,6 +71,7 @@ interface TrialJobDetail {
readonly workingDirectory: string;
readonly form: JobApplicationForm;
readonly sequenceId: number;
isEarlyStopped?: boolean;
}
interface HostJobDetail {
......
......@@ -99,7 +99,25 @@ class MockedDataStore implements DataStore {
private dbTrialJobs: SimpleDb = new SimpleDb('trial_jobs', './trial_jobs.json');
private dbMetrics: SimpleDb = new SimpleDb('metrics', './metrics.json');
trailJob1 = {
event: 'ADD_CUSTOMIZED',
timestamp: Date.now(),
trialJobId: "4321",
data: ''
}
metrics1 = {
timestamp: Date.now(),
trialJobId: '4321',
parameterId: 'param1',
type: 'CUSTOM',
sequence: 21,
data: ''
}
// Seeds the mocked store: saves the trailJob1 fixture into dbTrialJobs and the
// metrics1 fixture into dbMetrics, then returns an already-resolved promise.
// NOTE(review): the return values of saveData() are not inspected here — confirm
// that saveData is synchronous if callers rely on the data being present.
init(): Promise<void> {
this.dbTrialJobs.saveData(this.trailJob1);
this.dbMetrics.saveData(this.metrics1);
return Promise.resolve();
}
......
......@@ -19,25 +19,27 @@
'use strict';
import * as os from 'os';
import { assert, expect } from 'chai';
import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component';
import { Database, DataStore } from '../../common/datastore';
import { Manager } from '../../common/manager';
import { Manager, ExperimentProfile} from '../../common/manager';
import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { NNIDataStore } from '../nniDataStore';
import { NNIManager } from '../nnimanager';
import { SqlDB } from '../sqlDatabase';
import { MockedTrainingService } from './mockedTrainingService';
import { MockedDataStore } from './mockedDatastore';
async function initContainer(): Promise<void> {
prepareUnitTest();
Container.bind(TrainingService).to(MockedTrainingService).scope(Scope.Singleton);
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
await component.get<DataStore>(DataStore).init();
}
......@@ -51,9 +53,9 @@ describe('Unit test for nnimanager', function () {
let experimentParams = {
authorName: 'zql',
experimentName: 'naive_experiment',
trialConcurrency: 2,
trialConcurrency: 3,
maxExecDuration: 5,
maxTrialNum: 2,
maxTrialNum: 3,
trainingServicePlatform: 'local',
searchSpace: '{"x":1}',
tuner: {
......@@ -71,36 +73,74 @@ describe('Unit test for nnimanager', function () {
}
}
let updateExperimentParams = {
authorName: '',
experimentName: 'another_experiment',
trialConcurrency: 2,
maxExecDuration: 6,
maxTrialNum: 2,
trainingServicePlatform: 'local',
searchSpace: '{"y":2}',
tuner: {
className: 'TPE',
classArgs: {
optimize_mode: 'maximize'
},
checkpointDir: '',
gpuNum: 0
},
assessor: {
className: 'Medianstop',
checkpointDir: '',
gpuNum: 1
}
}
let experimentProfile = {
params: updateExperimentParams,
id: 'test',
execDuration: 0,
maxSequenceId: 0,
revision: 0
}
before(async () => {
await initContainer();
nniManager = component.get(Manager);
const expId: string = await nniManager.startExperiment(experimentParams);
assert(expId);
});
assert.strictEqual(expId, 'unittest');
})
// Suite teardown: stops the experiment and cleans up the unit-test environment.
// NOTE(review): setTimeout() returns a timer handle, not a promise, so the
// `await` in front of it does not wait for the 15000 ms delay. cleanupUnitTest()
// therefore runs immediately, before the delayed stopExperiment() call fires,
// and the promise returned by that delayed call is not observed by anyone.
after(async () => {
await nniManager.stopExperiment();
await setTimeout(() => {nniManager.stopExperiment()},15000);
cleanupUnitTest();
})
it('test resumeExperiment', () => {
//TODO: add resume experiment unit test
it('test addCustomizedTrialJob', () => {
return nniManager.addCustomizedTrialJob('hyperParams').then(() => {
}).catch((error) => {
assert.fail(error);
})
})
it('test listTrialJobs', () => {
//FIXME: not implemented
//return nniManager.listTrialJobs().then(function (trialJobDetails) {
// expect(trialJobDetails.length).to.be.equal(2);
//}).catch(function (error) {
// assert.fail(error);
//})
return nniManager.listTrialJobs().then(function (trialjobdetails) {
expect(trialjobdetails.length).to.be.equal(2);
}).catch((error) => {
assert.fail(error);
})
})
it('test getTrialJob valid', () => {
// query an existing id
return nniManager.getTrialJob('1234').then(function (trialJobDetail) {
expect(trialJobDetail.id).to.be.equal('1234');
}).catch(function (error) {
}).catch((error) => {
assert.fail(error);
})
})
......@@ -132,7 +172,6 @@ describe('Unit test for nnimanager', function () {
})
})
//TODO: complete ut
it('test cancelTrialJobByUser', () => {
return nniManager.cancelTrialJobByUser('1234').then(() => {
......@@ -141,11 +180,112 @@ describe('Unit test for nnimanager', function () {
})
})
it('test addCustomizedTrialJob', () => {
return nniManager.addCustomizedTrialJob('hyperParams').then(() => {
it('test getExperimentProfile', () => {
return nniManager.getExperimentProfile().then((experimentProfile) => {
expect(experimentProfile.id).to.be.equal('unittest');
expect(experimentProfile.logDir).to.be.equal(os.homedir()+'/nni/experiments/unittest');
}).catch((error) => {
assert.fail(error);
})
})
// Updates the profile with the 'TRIAL_CONCURRENCY' update type and checks that
// the profile read back reports trialConcurrency === 2.
// Both calls are awaited so that a failed expectation (or a rejected call)
// fails this test instead of escaping as an unobserved promise rejection.
it('test updateExperimentProfile TRIAL_CONCURRENCY', async () => {
    await nniManager.updateExperimentProfile(experimentProfile, 'TRIAL_CONCURRENCY');
    const updateProfile = await nniManager.getExperimentProfile();
    expect(updateProfile.params.trialConcurrency).to.be.equal(2);
})
// Updates the profile with the 'MAX_EXEC_DURATION' update type and checks that
// the profile read back reports maxExecDuration === 6.
// Both calls are awaited so that a failed expectation (or a rejected call)
// fails this test instead of escaping as an unobserved promise rejection.
it('test updateExperimentProfile MAX_EXEC_DURATION', async () => {
    await nniManager.updateExperimentProfile(experimentProfile, 'MAX_EXEC_DURATION');
    const updateProfile = await nniManager.getExperimentProfile();
    expect(updateProfile.params.maxExecDuration).to.be.equal(6);
})
// Updates the profile with the 'SEARCH_SPACE' update type and checks that the
// profile read back carries the new search space string.
// Both calls are awaited so that a failed expectation (or a rejected call)
// fails this test instead of escaping as an unobserved promise rejection.
it('test updateExperimentProfile SEARCH_SPACE', async () => {
    await nniManager.updateExperimentProfile(experimentProfile, 'SEARCH_SPACE');
    const updateProfile = await nniManager.getExperimentProfile();
    expect(updateProfile.params.searchSpace).to.be.equal('{"y":2}');
})
// Updates the profile with the 'MAX_TRIAL_NUM' update type and checks that the
// profile read back reports maxTrialNum === 2.
// Both calls are awaited so that a failed expectation (or a rejected call)
// fails this test instead of escaping as an unobserved promise rejection.
it('test updateExperimentProfile MAX_TRIAL_NUM', async () => {
    await nniManager.updateExperimentProfile(experimentProfile, 'MAX_TRIAL_NUM');
    const updateProfile = await nniManager.getExperimentProfile();
    expect(updateProfile.params.maxTrialNum).to.be.equal(2);
})
it('test getStatus', () => {
assert.strictEqual(nniManager.getStatus().status,'RUNNING');
})
it('test getMetricData with trialJobId', () => {
// query an existing trialJobId
return nniManager.getMetricData('4321', 'CUSTOM').then((metricData) => {
expect(metricData.length).to.be.equal(1);
expect(metricData[0].trialJobId).to.be.equal('4321');
expect(metricData[0].parameterId).to.be.equal('param1');
}).catch((error) => {
assert.fail(error);
})
})
// Queries metric data for a trialJobId that the other tests in this suite do not use.
// NOTE(review): assert.fail() throws inside the same promise chain, so the empty
// catch handler below absorbs it as well — as written this test passes whether
// getMetricData resolves or rejects. Presumably a rejection is expected here;
// confirm the intended behaviour before tightening the assertion.
it('test getMetricData with invalid trialJobId', () => {
//query an invalid trialJobId
return nniManager.getMetricData('43210', 'CUSTOM').then((metricData) => {
assert.fail();
}).catch((error) => {
})
})
it('test getTrialJobStatistics', () => {
// get 3 trial jobs (init, addCustomizedTrialJob, cancelTrialJobByUser)
return nniManager.getTrialJobStatistics().then(function (trialJobStatistics) {
expect(trialJobStatistics.length).to.be.equal(2);
if (trialJobStatistics[0].trialJobStatus === 'WAITING') {
expect(trialJobStatistics[0].trialJobNumber).to.be.equal(2);
expect(trialJobStatistics[1].trialJobNumber).to.be.equal(1);
}
else {
expect(trialJobStatistics[1].trialJobNumber).to.be.equal(2);
expect(trialJobStatistics[0].trialJobNumber).to.be.equal(1);
}
}).catch((error) => {
assert.fail(error);
})
})
// Adds one more customized trial job and then checks the trial job statistics.
// The statistics query is awaited so its expectations actually decide the
// outcome of this test; previously the inner promise was neither returned nor
// awaited, so a failing expectation could not fail the test.
it('test addCustomizedTrialJob reach maxTrialNum', async () => {
    // test currSubmittedTrialNum reach maxTrialNum
    await nniManager.addCustomizedTrialJob('hyperParam');
    const trialJobStatistics = await nniManager.getTrialJobStatistics();
    // The WAITING entry may be reported first or second, so check whichever
    // slot the original test associated with it.
    if (trialJobStatistics[0].trialJobStatus === 'WAITING') {
        expect(trialJobStatistics[0].trialJobNumber).to.be.equal(2);
    } else {
        expect(trialJobStatistics[1].trialJobNumber).to.be.equal(2);
    }
})
it('test resumeExperiment', async () => {
//TODO: add resume experiment unit test
})
})
......@@ -5,7 +5,7 @@
"scripts": {
"postbuild": "cp -rf config ./dist/",
"build": "tsc",
"test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --exclude core/test/nnimanager.test.ts --colors",
"test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --colors",
"start": "node dist/main.js",
"tslint": "tslint -p ."
},
......
......@@ -46,6 +46,7 @@ export class KubeflowJobInfoCollector extends KubernetesJobInfoCollector{
try {
kubernetesJobInfo = await kubernetesCRDClient.getKubernetesJob(kubernetesTrialJob.kubernetesJobName);
} catch(error) {
// Note: this may not be a 'real' error, since cancelling a trial job can also cause getKubernetesJob to fail.
this.log.error(`Get job ${kubernetesTrialJob.kubernetesJobName} info failed, error is ${error}`);
// This is not treated as an error status
return Promise.resolve();
......
......@@ -255,7 +255,7 @@ class LocalTrainingService implements TrainingService {
}
if (trialJob.pid === undefined){
this.setTrialJobStatus(trialJob, 'USER_CANCELED');
return;
return Promise.resolve();
}
if (trialJob.form.jobType === 'TRIAL') {
await tkill(trialJob.pid, 'SIGKILL');
......@@ -265,6 +265,7 @@ class LocalTrainingService implements TrainingService {
throw new Error(`Job type not supported: ${trialJob.form.jobType}`);
}
this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped));
return Promise.resolve();
}
public async setClusterMetadata(key: string, value: string): Promise<void> {
......
......@@ -34,6 +34,7 @@ export class PAITrialJobDetail implements TrialJobDetail {
public form: JobApplicationForm;
public sequenceId: number;
public hdfsLogPath: string;
public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, paiJobName : string,
submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) {
......@@ -63,7 +64,7 @@ export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4}
&& cd $NNI_SYS_DIR && sh install_nni.sh
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}'
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}'`;
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1'`;
export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;
......
......@@ -103,8 +103,12 @@ export class PAIJobInfoCollector {
paiTrialJob.status = 'SUCCEEDED';
break;
case 'STOPPED':
if (paiTrialJob.status !== 'EARLY_STOPPED') {
paiTrialJob.status = 'USER_CANCELED';
if (paiTrialJob.isEarlyStopped !== undefined) {
paiTrialJob.status = paiTrialJob.isEarlyStopped === true ?
'EARLY_STOPPED' : 'USER_CANCELED';
} else {
// If paiTrialJob's isEarlyStopped is undefined, we did not stop it via cancellation, so it was stopped by PAI itself: mark it as SYS_CANCELED
paiTrialJob.status = 'SYS_CANCELED';
}
break;
case 'FAILED':
......
......@@ -324,14 +324,15 @@ class PAITrainingService implements TrainingService {
"Authorization": 'Bearer ' + this.paiToken
}
};
// Set trialjobDetail's early stopped field, to mark the job's cancellation source
trialJobDetail.isEarlyStopped = isEarlyStopped;
request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
if (error || response.statusCode >= 400) {
this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
deferred.reject(error ? error.message : 'Stop trial failed, http code: ' + response.statusCode);
} else {
if (isEarlyStopped) {
trialJobDetail.status = 'EARLY_STOPPED';
}
deferred.resolve();
}
});
......
......@@ -80,6 +80,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public form: JobApplicationForm;
public sequenceId: number;
public rmMeta?: RemoteMachineMeta;
public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
......@@ -114,7 +115,7 @@ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={
cd $NNI_SYS_DIR
sh install_nni.sh
echo $$ >{6}
python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}'
python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >{10}`;
export const HOST_JOB_SHELL_FORMAT: string =
......
......@@ -48,7 +48,7 @@ import {
GPU_COLLECTOR_FORMAT
} from './remoteMachineData';
import { SSHClientUtility } from './sshClientUtility';
import { validateCodeDir} from '../common/util';
import { validateCodeDir } from '../common/util';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { mkDirP } from '../../common/utils';
......@@ -279,8 +279,9 @@ class RemoteMachineTrainingService implements TrainingService {
const jobpidPath: string = this.getJobPidPath(trialJob.id);
try {
// Mark the toEarlyStop tag here
trialJob.isEarlyStopped = isEarlyStopped;
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, sshClient);
trialJob.status = getJobCancelStatus(isEarlyStopped);
} catch (error) {
// The error is not handled further, since a failed pkill does not affect the trial job's current status
this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`);
......@@ -482,6 +483,11 @@ class RemoteMachineTrainingService implements TrainingService {
if (trialJobDetail === undefined) {
throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${trialJobId}`);
}
// If the job is not WAITING, don't prepare it; resolve true immediately
if (trialJobDetail.status !== 'WAITING') {
deferred.resolve(true);
return deferred.promise;
}
// get an ssh client from scheduler
const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobId);
if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
......@@ -640,7 +646,12 @@ class RemoteMachineTrainingService implements TrainingService {
if (parseInt(code, 10) === 0) {
trialJob.status = 'SUCCEEDED';
} else {
// isEarlyStopped was never set, which means the job was not cancelled by NNI, so a non-zero exit code is marked as FAILED
if (trialJob.isEarlyStopped === undefined) {
trialJob.status = 'FAILED';
} else {
trialJob.status = getJobCancelStatus(trialJob.isEarlyStopped);
}
}
trialJob.endTime = parseInt(timestamp, 10);
}
......
......@@ -19,14 +19,106 @@
'use strict';
import { TrainingService } from '../../common/trainingService';
import { LocalTrainingService } from '../local/localTrainingService';
import * as assert from 'assert';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import * as fs from 'fs';
import * as tmp from 'tmp';
import * as component from '../../common/component';
import { TrialJobApplicationForm, TrialJobDetail, TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { LocalTrainingServiceForGPU } from '../local/localTrainingServiceForGPU';
// TODO: copy mockedTrial.py to local folder
const localCodeDir: string = tmp.dirSync().name
const mockedTrialPath: string = './training_service/test/mockedTrial.py'
fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
describe('Unit Test for LocalTrainingService', () => {
let trainingService: TrainingService
let trialConfig: any = `{"command":"sleep 1h && echo hello","codeDir":"${localCodeDir}","gpuNum":1}`
let localTrainingService: LocalTrainingServiceForGPU;
before(() => {
chai.should();
chai.use(chaiAsPromised);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
});
beforeEach(() => {
localTrainingService = component.get(LocalTrainingServiceForGPU);
localTrainingService.run();
});
afterEach(() => {
localTrainingService.cleanUp();
});
it('List empty trial jobs', async () => {
//trial jobs should be empty, since there are no submitted jobs
chai.expect(await localTrainingService.listTrialJobs()).to.be.empty;
});
// Round-trips the trial config through setClusterMetadata/getClusterMetadata
// and checks that the value read back equals the value that was set.
// The read is awaited so the comparison runs before the test completes;
// previously the promise was left dangling and its expectation could not
// fail the test.
it('setClusterMetadata and getClusterMetadata', async () => {
    await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
    const data = await localTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG);
    chai.expect(data).to.be.equals(trialConfig);
});
it('Submit job and Cancel job', async () => {
await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const jobDetail: TrialJobDetail = await localTrainingService.submitTrialJob(form);
chai.expect(jobDetail.status).to.be.equals('WAITING');
await localTrainingService.cancelTrialJob(jobDetail.id);
chai.expect(jobDetail.status).to.be.equals('USER_CANCELED');
}).timeout(20000);
// Submits a trial whose command is `python3 mockedTrial.py`, registers a metric
// listener, waits one second for metrics, then cancels the job and removes the
// listener again.
it('Read metrics, Add listener, and remove listener', async () => {
    // set meta data
    const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}`
    await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);

    // submit job
    const form: TrialJobApplicationForm = {
        jobType: 'TRIAL',
        hyperParameters: {
            value: 'mock hyperparameters',
            index: 0
        }
    };
    const jobDetail: TrialJobDetail = await localTrainingService.submitTrialJob(form);
    chai.expect(jobDetail.status).to.be.equals('WAITING');

    // Awaited so that a wrong job count fails this test; previously the promise
    // was left dangling and its expectation could not affect the result.
    const jobList = await localTrainingService.listTrialJobs();
    chai.expect(jobList.length).to.be.equals(1);

    // Add metrics listeners
    const listener1 = function f1(metric: any) {
        chai.expect(metric.id).to.be.equals(jobDetail.id);
    }
    localTrainingService.addTrialJobMetricListener(listener1);

    // Wait to collect metric
    await delay(1000);

    await localTrainingService.cancelTrialJob(jobDetail.id);
    localTrainingService.removeTrialJobMetricListener(listener1);
}).timeout(20000);
beforeEach(async () => {
trainingService = component.get(LocalTrainingService);
it('Test multiphaseSupported', () => {
chai.expect(localTrainingService.isMultiPhaseJobSupported).to.be.equals(true)
})
});
\ No newline at end of file
......@@ -182,6 +182,7 @@ class SlideBar extends React.Component<{}, SliderState> {
render() {
const { version, menuVisible } = this.state;
const feed = `https://github.com/Microsoft/nni/issues/new?labels=${version}`;
const menu = (
<Menu onClick={this.handleMenuClick}>
<Menu.Item key="1">Experiment Parameters</Menu.Item>
......@@ -221,7 +222,7 @@ class SlideBar extends React.Component<{}, SliderState> {
Download <Icon type="down" />
</a>
</Dropdown>
<a href="https://github.com/Microsoft/nni/issues/new?labels=v0.5.1" target="_blank">
<a href={feed} target="_blank">
<img
src={require('../static/img/icon/issue.png')}
alt="NNI github issue"
......
......@@ -38,7 +38,6 @@ def gen_new_config(config_file, training_service='local'):
new_config_file = config_file + '.tmp'
ts = get_yml_content('training_service.yml')[training_service]
print(config)
print(ts)
# hack for kubeflow trial config
......@@ -64,7 +63,6 @@ def run_test(config_file, training_service, local_gpu=False):
return
try:
print('Testing %s...' % config_file)
proc = subprocess.run(['nnictl', 'create', '--config', new_config_file])
assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode
......@@ -109,8 +107,10 @@ def run(args):
try:
# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(5)
print(GREEN + 'Testing:' + config_file + CLEAR)
begin_time = time.time()
run_test(config_file, args.ts, args.local_gpu)
print(GREEN + 'Test %s: TEST PASS' % (config_file) + CLEAR)
print(GREEN + 'Test %s: TEST PASS IN %d mins' % (config_file, (time.time() - begin_time)/60) + CLEAR)
except Exception as error:
print(RED + 'Test %s: TEST FAIL' % (config_file) + CLEAR)
print('%r' % error)
......
jobs:
- job: 'integration_test_kubeflow'
timeoutInMinutes: 0
pool: 'NNI CI KUBE CLI'
variables:
......
jobs:
- job: 'integration_test_pai'
timeoutInMinutes: 0
pool: 'NNI CI PAI CLI'
variables:
......
......@@ -168,7 +168,9 @@ def set_remote_config(experiment_config, port, config_file_name):
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
......
......@@ -48,10 +48,25 @@ def main_loop(args):
# redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout)
sys.stdout = sys.stderr = trial_keeper_syslogger
if args.pai_hdfs_host is not None and args.nni_hdfs_exp_dir is not None:
# backward compatibility
hdfs_host = None
hdfs_output_dir = None
if args.hdfs_host:
hdfs_host = args.hdfs_host
elif args.pai_hdfs_host:
hdfs_host = args.pai_hdfs_host
if args.hdfs_output_dir:
hdfs_output_dir = args.hdfs_output_dir
elif args.pai_hdfs_output_dir:
hdfs_output_dir = args.pai_hdfs_output_dir
if hdfs_host is not None and args.nni_hdfs_exp_dir is not None:
try:
hdfs_client = HdfsClient(hosts='{0}:{1}'.format(args.pai_hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5)
if args.webhdfs_path:
hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name, webhdfs_path=args.webhdfs_path, timeout=5)
else:
# backward compatibility
hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5)
except Exception as e:
nni_log(LogType.Error, 'Create HDFS client error: ' + str(e))
raise e
......@@ -67,14 +82,14 @@ def main_loop(args):
# child worker process exits and all stdout data is read
if retCode is not None and log_pipe_stdout.set_process_exit() and log_pipe_stdout.is_read_completed == True:
nni_log(LogType.Info, 'subprocess terminated. Exit code is {}. Quit'.format(retCode))
if args.pai_hdfs_output_dir is not None:
if hdfs_output_dir is not None:
# Copy local directory to hdfs for OpenPAI
nni_local_output_dir = os.environ['NNI_OUTPUT_DIR']
try:
if copyDirectoryToHdfs(nni_local_output_dir, args.pai_hdfs_output_dir, hdfs_client):
nni_log(LogType.Info, 'copy directory from {0} to {1} success!'.format(nni_local_output_dir, args.pai_hdfs_output_dir))
if copyDirectoryToHdfs(nni_local_output_dir, hdfs_output_dir, hdfs_client):
nni_log(LogType.Info, 'copy directory from {0} to {1} success!'.format(nni_local_output_dir, hdfs_output_dir))
else:
nni_log(LogType.Info, 'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, args.pai_hdfs_output_dir))
nni_log(LogType.Info, 'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, hdfs_output_dir))
except Exception as e:
nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e))
raise e
......@@ -95,10 +110,13 @@ if __name__ == '__main__':
PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of pai_hdfs') # backward compatibility
PARSER.add_argument('--hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of pai_hdfs') # backward compatibility
PARSER.add_argument('--hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs')
PARSER.add_argument('--webhdfs_path', type=str, help='the webhdfs path used in webhdfs URL')
args, unknown = PARSER.parse_known_args()
if args.trial_command is None:
exit(1)
......
......@@ -12,7 +12,7 @@ setuptools.setup(
'psutil',
'astor',
'schema',
'pyhdfs'
'PythonWebHDFS'
],
author = 'Microsoft NNI Team',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment