You need to sign in or sign up before continuing.
Unverified Commit db19946d authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support AAD token login in PAI mode (#1660)

parent 52b93d0c
...@@ -82,6 +82,22 @@ Compared with [LocalMode](LocalMode.md) and [RemoteMachineMode](RemoteMachineMod ...@@ -82,6 +82,22 @@ Compared with [LocalMode](LocalMode.md) and [RemoteMachineMode](RemoteMachineMod
portNumber: 1 portNumber: 1
``` ```
NNI supports two kinds of authorization methods in PAI: password and PAI token. Refer to the [PAI authentication documentation](https://github.com/microsoft/pai/blob/b6bd2ab1c8890f91b7ac5859743274d2aa923c22/docs/rest-server/API.md#2-authentication) for details. The authorization method is configured in the `paiConfig` field.
For password authorization, the `paiConfig` schema is:
```
paiConfig:
userName: your_pai_nni_user
passWord: your_pai_password
host: 10.1.1.1
```
For PAI token authorization, the `paiConfig` schema is:
```
paiConfig:
userName: your_pai_nni_user
token: your_pai_token
host: 10.1.1.1
```
Once you have filled in the NNI experiment config file and saved it (for example, as exp_pai.yml), run the following command Once you have filled in the NNI experiment config file and saved it (for example, as exp_pai.yml), run the following command
``` ```
nnictl create --config exp_pai.yml nnictl create --config exp_pai.yml
......
...@@ -107,7 +107,8 @@ export namespace ValidationSchemas { ...@@ -107,7 +107,8 @@ export namespace ValidationSchemas {
}), }),
pai_config: joi.object({ pai_config: joi.object({
userName: joi.string().min(1).required(), userName: joi.string().min(1).required(),
passWord: joi.string().min(1).required(), passWord: joi.string().min(1),
token: joi.string().min(1),
host: joi.string().min(1).required() host: joi.string().min(1).required()
}), }),
kubeflow_config: joi.object({ kubeflow_config: joi.object({
......
...@@ -107,19 +107,22 @@ export class PAIJobConfig { ...@@ -107,19 +107,22 @@ export class PAIJobConfig {
*/ */
export class PAIClusterConfig { export class PAIClusterConfig {
public readonly userName: string; public readonly userName: string;
public readonly passWord: string; public readonly passWord?: string;
public readonly host: string; public readonly host: string;
public readonly token?: string;
/** /**
* Constructor * Constructor
* @param userName User name of PAI Cluster * @param userName User name of PAI Cluster
* @param passWord password of PAI Cluster * @param passWord password of PAI Cluster
* @param host Host IP of PAI Cluster * @param host Host IP of PAI Cluster
* @param token PAI token of PAI Cluster
*/ */
constructor(userName: string, passWord : string, host : string) { constructor(userName: string, host : string, passWord?: string, token?: string) {
this.userName = userName; this.userName = userName;
this.passWord = passWord; this.passWord = passWord;
this.host = host; this.host = host;
this.token = token;
} }
} }
......
...@@ -208,7 +208,7 @@ class PAITrainingService implements TrainingService { ...@@ -208,7 +208,7 @@ class PAITrainingService implements TrainingService {
const stopJobRequest: request.Options = { const stopJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\ uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\
/jobs/${trialJobDetail.paiJobName}/executionType`, /jobs/${trialJobDetail.paiJobName}/executionType`,
method: 'PUT', method: 'PUT',
json: true, json: true,
body: {value: 'STOP'}, body: {value: 'STOP'},
...@@ -256,9 +256,15 @@ class PAITrainingService implements TrainingService { ...@@ -256,9 +256,15 @@ class PAITrainingService implements TrainingService {
path: '/webhdfs/api/v1', path: '/webhdfs/api/v1',
host: this.paiClusterConfig.host host: this.paiClusterConfig.host
}); });
if(this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if(this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
} else {
deferred.reject(new Error('pai cluster config format error, please set password or token!'));
}
// Get PAI authentication token
await this.updatePaiToken();
deferred.resolve(); deferred.resolve();
break; break;
...@@ -483,8 +489,7 @@ class PAITrainingService implements TrainingService { ...@@ -483,8 +489,7 @@ class PAITrainingService implements TrainingService {
request(submitJobRequest, (error: Error, response: request.Response, body: any) => { request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 400) { if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage : string = (error !== undefined && error !== null) ? error.message : const errorMessage : string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body}`; `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body.message}`;
this.log.error(errorMessage);
trialJobDetail.status = 'FAILED'; trialJobDetail.status = 'FAILED';
deferred.resolve(true); deferred.resolve(true);
} else { } else {
...@@ -498,13 +503,15 @@ class PAITrainingService implements TrainingService { ...@@ -498,13 +503,15 @@ class PAITrainingService implements TrainingService {
private async statusCheckingLoop(): Promise<void> { private async statusCheckingLoop(): Promise<void> {
while (!this.stopping) { while (!this.stopping) {
try { if(this.paiClusterConfig && this.paiClusterConfig.passWord) {
await this.updatePaiToken(); try {
} catch (error) { await this.updatePaiToken();
this.log.error(`${error}`); } catch (error) {
//only throw error when initlize paiToken first time this.log.error(`${error}`);
if (this.paiToken === undefined) { //only throw error when initlize paiToken first time
throw new Error(error); if (this.paiToken === undefined) {
throw new Error(error);
}
} }
} }
await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig); await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig);
......
...@@ -265,11 +265,15 @@ pai_trial_schema = { ...@@ -265,11 +265,15 @@ pai_trial_schema = {
} }
pai_config_schema = { pai_config_schema = {
'paiConfig':{ 'paiConfig': Or({
'userName': setType('userName', str), 'userName': setType('userName', str),
'passWord': setType('passWord', str), 'passWord': setType('passWord', str),
'host': setType('host', str) 'host': setType('host', str)
} }, {
'userName': setType('userName', str),
'token': setType('token', str),
'host': setType('host', str)
})
} }
kubeflow_trial_schema = { kubeflow_trial_schema = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment