OpenDAS / nni / Commits / 543239c6

Commit 543239c6 (unverified), authored Dec 12, 2019 by SparkSnail, committed via GitHub Dec 12, 2019

    Merge pull request #220 from microsoft/master

    merge master

Parents: 32efaa36, 659480f2
Showing 14 changed files, with 371 additions and 409 deletions (+371 −409):

    src/nni_manager/training_service/pai/paiJobInfoCollector.ts                      +7   −12
    src/nni_manager/training_service/pai/paiJobRestServer.ts                         +2   −3
    src/nni_manager/training_service/pai/paiTrainingService.ts                       +19  −33
    src/nni_manager/training_service/pai/paiTrialConfig.ts                           +1   −1
    src/nni_manager/training_service/remote_machine/gpuScheduler.ts                  +11  −14
    src/nni_manager/training_service/remote_machine/remoteMachineData.ts             +10  −11
    src/nni_manager/training_service/remote_machine/remoteMachineJobRestServer.ts    +2   −3
    src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts  +24  −38
    src/nni_manager/training_service/remote_machine/sshClientUtility.ts              +37  −41
    src/nni_manager/yarn.lock                                                        +8   −84
    src/sdk/pynni/nni/batch_tuner/batch_tuner.py                                     +9   −7
    src/sdk/pynni/nni/compression/torch/builtin_pruners.py                           +211 −143
    src/sdk/pynni/nni/msg_dispatcher_base.py                                         +19  −11
    src/sdk/pynni/tests/test_compressor.py                                           +11  −8
src/nni_manager/training_service/pai/paiJobInfoCollector.ts

@@ -3,7 +3,6 @@
 'use strict';

-// tslint:disable-next-line:no-implicit-dependencies
 import * as request from 'request';
 import { Deferred } from 'ts-deferred';
 import { NNIError, NNIErrorNames } from '../../common/errors';
@@ -16,10 +15,10 @@ import { PAITrialJobDetail } from './paiData';
  * Collector PAI jobs info from PAI cluster, and update pai job status locally
  */
 export class PAIJobInfoCollector {
     private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
     private readonly log: Logger = getLogger();
     private readonly statusesNeedToCheck: TrialJobStatus[];
     private readonly finalStatuses: TrialJobStatus[];

     constructor(jobMap: Map<string, PAITrialJobDetail>) {
         this.trialJobsMap = jobMap;
@@ -27,12 +26,12 @@ export class PAIJobInfoCollector {
         this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
     }

-    public async retrieveTrialStatus(paiToken?: string, paiClusterConfig?: PAIClusterConfig) : Promise<void> {
+    public async retrieveTrialStatus(paiToken?: string, paiClusterConfig?: PAIClusterConfig): Promise<void> {
         if (paiClusterConfig === undefined || paiToken === undefined) {
             return Promise.resolve();
         }
         const updatePaiTrialJobs: Promise<void>[] = [];
         for (const [trialJobId, paiTrialJob] of this.trialJobsMap) {
             if (paiTrialJob === undefined) {
                 throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
@@ -43,9 +42,8 @@ export class PAIJobInfoCollector {
         await Promise.all(updatePaiTrialJobs);
     }

-    private getSinglePAITrialJobInfo(paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig)
-        : Promise<void> {
+    private getSinglePAITrialJobInfo(paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig): Promise<void> {
         const deferred: Deferred<void> = new Deferred<void>();
         if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
             deferred.resolve();
@@ -55,7 +53,6 @@ export class PAIJobInfoCollector {
         // Rest call to get PAI job info and update status
         // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
         const getJobInfoRequest: request.Options = {
-            // tslint:disable-next-line:no-http-string
             uri: `http://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
             method: 'GET',
             json: true,
@@ -65,7 +62,6 @@ export class PAIJobInfoCollector {
             }
         };
-        // tslint:disable: no-unsafe-any no-any cyclomatic-complexity
         //TODO : pass in request timeout param?
         request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 500) {
@@ -129,5 +125,4 @@ export class PAIJobInfoCollector {
         return deferred.promise;
     }
-    // tslint:enable: no-unsafe-any no-any
 }
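The collector above bridges Node's callback-style request API to Promises by routing every callback through a ts-deferred Deferred, and resolves even on transient HTTP failures so one bad poll cannot reject the Promise.all in retrieveTrialStatus. A minimal sketch of that pattern, assuming a placeholder URI and an illustrative response shape (body.jobStatus.state; the real handler inspects more fields):

import * as request from 'request';
import { Deferred } from 'ts-deferred';

// Sketch only: wrap callback-style request() in a Promise via ts-deferred,
// mirroring getSinglePAITrialJobInfo. The URI, token, and body.jobStatus.state
// shape are illustrative placeholders, not values taken from this diff.
function getJobState(uri: string, paiToken: string): Promise<string> {
    const deferred: Deferred<string> = new Deferred<string>();
    const getJobInfoRequest: request.Options = {
        uri: uri,
        method: 'GET',
        json: true,
        headers: { Authorization: `Bearer ${paiToken}` }
    };
    request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => {
        if ((error !== undefined && error !== null) || response.statusCode >= 500) {
            // Treat transient errors as non-fatal: resolve so Promise.all keeps going.
            deferred.resolve('UNKNOWN');
        } else {
            deferred.resolve(body.jobStatus.state);
        }
    });
    return deferred.promise;
}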
src/nni_manager/training_service/pai/paiJobRestServer.ts

@@ -24,7 +24,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
     private parameterFileMetaList: ParameterFileMeta[] = [];

     @Inject
     private readonly paiTrainingService: PAITrainingService;

     /**
      * constructor to provide NNIRestServer's own rest property, e.g. port
@@ -34,8 +34,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
         this.paiTrainingService = component.get(PAITrainingService);
     }

-    // tslint:disable-next-line:no-any
-    protected handleTrialMetrics(jobId: string, metrics: any[]) : void {
+    protected handleTrialMetrics(jobId: string, metrics: any[]): void {
         // Split metrics array into single metric, then emit
         // Warning: If not split metrics into single ones, the behavior will be UNKNOWN
         for (const singleMetric of metrics) {
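handleTrialMetrics above warns that metrics must be emitted one at a time. A minimal sketch of that splitting, assuming a plain EventEmitter and an illustrative event name and payload shape:

import { EventEmitter } from 'events';

// Sketch only: emit each metric individually; handing listeners the whole
// array at once gives them an unexpected payload shape ("behavior will be
// UNKNOWN" per the comment above). Event name and payload are assumptions.
function emitTrialMetrics(emitter: EventEmitter, jobId: string, metrics: any[]): void {
    for (const singleMetric of metrics) {
        emitter.emit('metric', { id: jobId, data: singleMetric });
    }
}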
src/nni_manager/training_service/pai/paiTrainingService.ts

@@ -3,17 +3,14 @@
 'use strict';

-import * as cpp from 'child-process-promise';
 import * as fs from 'fs';
 import * as path from 'path';
-// tslint:disable-next-line:no-implicit-dependencies
 import * as request from 'request';
 import * as component from '../../common/component';
 import { EventEmitter } from 'events';
 import { Deferred } from 'ts-deferred';
 import { String } from 'typescript-string-operations';
-import { MethodNotImplementedError } from '../../common/errors';
 import { getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
 import {
@@ -47,13 +44,12 @@ class PAITrainingService implements TrainingService {
     private paiClusterConfig?: PAIClusterConfig;
     private readonly jobQueue: string[];
     private stopping: boolean = false;
-    // tslint:disable-next-line:no-any
     private hdfsClient: any;
     private paiToken?: string;
     private paiTokenUpdateTime?: number;
     private readonly paiTokenUpdateInterval: number;
     private readonly experimentId!: string;
     private readonly paiJobCollector: PAIJobInfoCollector;
     private paiRestServerPort?: number;
     private nniManagerIpConfig?: NNIManagerIpConfig;
     private copyExpCodeDirPromise?: Promise<void>;
@@ -126,7 +122,7 @@ class PAITrainingService implements TrainingService {
         if (this.paiClusterConfig === undefined) {
             throw new Error(`paiClusterConfig not initialized!`);
         }
         const deferred: Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
         this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
@@ -137,7 +133,7 @@ class PAITrainingService implements TrainingService {
         const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
         const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
         const hdfsLogPath: string = String.Format(
             PAI_LOG_PATH_FORMAT,
             this.paiClusterConfig.host,
             hdfsOutputDir
@@ -173,10 +169,9 @@ class PAITrainingService implements TrainingService {
         return true;
     }

-    // tslint:disable:no-http-string
     public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
         const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
         const deferred: Deferred<void> = new Deferred<void>();
         if (trialJobDetail === undefined) {
             this.log.error(`cancelTrialJob: trial job id ${trialJobId} not found`);
@@ -205,7 +200,6 @@ class PAITrainingService implements TrainingService {
         // Set trialjobDetail's early stopped field, to mark the job's cancellation source
         trialJobDetail.isEarlyStopped = isEarlyStopped;
-        // tslint:disable-next-line:no-any
         request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                 this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
@@ -219,10 +213,8 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }

-    // tslint:disable: no-unsafe-any no-any
-    // tslint:disable-next-line:max-func-body-length
     public async setClusterMetadata(key: string, value: string): Promise<void> {
         const deferred: Deferred<void> = new Deferred<void>();
         switch (key) {
             case TrialConfigMetadataKey.NNI_MANAGER_IP:
@@ -300,10 +292,9 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
-    // tslint:enable: no-unsafe-any

     public getClusterMetadata(key: string): Promise<string> {
         const deferred: Deferred<string> = new Deferred<string>();
         deferred.resolve();
@@ -314,14 +305,13 @@ class PAITrainingService implements TrainingService {
         this.log.info('Stopping PAI training service...');
         this.stopping = true;

         const deferred: Deferred<void> = new Deferred<void>();
         const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
         try {
             await restServer.stop();
             deferred.resolve();
             this.log.info('PAI Training service rest server stopped successfully.');
         } catch (error) {
-            // tslint:disable-next-line: no-unsafe-any
             this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
             deferred.reject(error);
         }
@@ -329,13 +319,12 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }

-    public get MetricsEmitter() : EventEmitter {
+    public get MetricsEmitter(): EventEmitter {
         return this.metricsEmitter;
     }

-    // tslint:disable-next-line:max-func-body-length
     private async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
         const deferred: Deferred<boolean> = new Deferred<boolean>();
         const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
         if (trialJobDetail === undefined) {
@@ -372,7 +361,7 @@ class PAITrainingService implements TrainingService {
         //create tmp trial working folder locally.
         await execMkdir(trialLocalTempFolder);

         const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
         // Write NNI installation file to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
@@ -385,10 +374,9 @@ class PAITrainingService implements TrainingService {
         }
         const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
         const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
-        // tslint:disable-next-line: strict-boolean-expressions
         const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
         const version: string = this.versionCheck ? await getVersion() : '';
         const nniPaiTrialCommand: string = String.Format(
             PAI_TRIAL_COMMAND_FORMAT,
             // PAI will copy job's codeDir into /root directory
             `$PWD/${trialJobId}`,
@@ -409,9 +397,8 @@ class PAITrainingService implements TrainingService {
         )
             .replace(/\r\n|\n|\r/gm, '');

-        // tslint:disable-next-line:no-console
         this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);

         const paiTaskRoles: PAITaskRole[] = [
             new PAITaskRole(
                 `nni_trail_${trialJobId}`,
                 // Task role number
@@ -431,7 +418,7 @@ class PAITrainingService implements TrainingService {
             )
         ];

         const paiJobConfig: PAIJobConfig = new PAIJobConfig(
             // Job name
             trialJobDetail.paiJobName,
             // Docker image
@@ -451,7 +438,7 @@ class PAITrainingService implements TrainingService {
             await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
         } catch (error) {
             this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
-            trialJobDetail.status = 'FAILED';
+            trialJobDetail.status = 'FAILED'; // eslint-disable-line require-atomic-updates
             deferred.resolve(true);

             return deferred.promise;
@@ -469,10 +456,9 @@ class PAITrainingService implements TrainingService {
                 Authorization: `Bearer ${this.paiToken}`
             }
         };

-        // tslint:disable:no-any no-unsafe-any
         request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                 const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                     `Submit trial ${trialJobId} failed, http code: ${response.statusCode}, http body: ${response.body.message}`;
                 trialJobDetail.status = 'FAILED';
                 deferred.resolve(true);
@@ -527,7 +513,7 @@ class PAITrainingService implements TrainingService {
      * Update pai token by the interval time or initialize the pai token
      */
     private async updatePaiToken(): Promise<void> {
         const deferred: Deferred<void> = new Deferred<void>();

         const currentTime: number = new Date().getTime();
         //If pai token initialized and not reach the interval time, do not update
@@ -603,7 +589,7 @@ class PAITrainingService implements TrainingService {
     }

     private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
         const deferred: Deferred<void> = new Deferred<void>();
         const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
         const req: request.Options = {
             uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
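updatePaiToken gates its refresh on paiTokenUpdateTime and paiTokenUpdateInterval, per the comment "If pai token initialized and not reach the interval time, do not update". A minimal sketch of that gate, assuming a hypothetical fetchToken helper and an assumed interval value:

// Sketch only: interval-gated refresh in the spirit of updatePaiToken.
// Field names mirror the diff (paiToken, paiTokenUpdateTime, paiTokenUpdateInterval);
// fetchToken() and the 2-hour interval are assumptions, not values from the diff.
class TokenCache {
    private paiToken?: string;
    private paiTokenUpdateTime?: number;
    private readonly paiTokenUpdateInterval: number = 7200000; // assumed: 2 hours in ms

    public async updatePaiToken(fetchToken: () => Promise<string>): Promise<void> {
        const currentTime: number = new Date().getTime();
        // If the token is initialized and the interval has not elapsed, skip the refresh.
        if (this.paiToken !== undefined && this.paiTokenUpdateTime !== undefined
            && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }
        this.paiToken = await fetchToken();
        this.paiTokenUpdateTime = currentTime;
    }
}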
src/nni_manager/training_service/pai/paiTrialConfig.ts

@@ -15,7 +15,7 @@ export class PAITrialConfig extends TrialConfig {
     public readonly dataDir: string;
     public readonly outputDir: string;

     constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
                 image: string, dataDir: string, outputDir: string) {
         super(command, codeDir, gpuNum);
         this.cpuNum = cpuNum;
src/nni_manager/training_service/remote_machine/gpuScheduler.ts

@@ -5,7 +5,6 @@
 import * as assert from 'assert';
 import { getLogger, Logger } from '../../common/log';
-import { TrialJobDetail } from '../../common/trainingService';
 import { randomSelect } from '../../common/utils';
 import { GPUInfo } from '../common/gpuData';
 import {
@@ -19,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
  */
 export class GPUScheduler {
     private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>;
     private readonly log: Logger = getLogger();
     private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
     private roundRobinIndex: number = 0;
@@ -29,7 +28,7 @@ export class GPUScheduler {
      * Constructor
      * @param machineSSHClientMap map from remote machine to sshClient
      */
     constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) {
         assert(machineSSHClientMap.size > 0);
         this.machineSSHClientMap = machineSSHClientMap;
         this.configuredRMs = Array.from(machineSSHClientMap.keys());
@@ -39,7 +38,7 @@ export class GPUScheduler {
      * Schedule a machine according to the constraints (requiredGPUNum)
      * @param requiredGPUNum required GPU number
      */
-    public scheduleMachine(requiredGPUNum: number | undefined, trialJobDetail: RemoteMachineTrialJobDetail) : RemoteMachineScheduleResult {
+    public scheduleMachine(requiredGPUNum: number | undefined, trialJobDetail: RemoteMachineTrialJobDetail): RemoteMachineScheduleResult {
         if (requiredGPUNum === undefined) {
             requiredGPUNum = 0;
         }
@@ -48,7 +47,7 @@ export class GPUScheduler {
         assert(allRMs.length > 0);

         // Step 1: Check if required GPU number not exceeds the total GPU number in all machines
         const eligibleRM: RemoteMachineMeta[] = allRMs.filter((rmMeta: RemoteMachineMeta) =>
             rmMeta.gpuSummary === undefined || requiredGPUNum === 0 || (requiredGPUNum !== undefined && rmMeta.gpuSummary.gpuCount >= requiredGPUNum));
         if (eligibleRM.length === 0) {
             // If the required gpu number exceeds the upper limit of all machine's GPU number
@@ -134,8 +133,8 @@ export class GPUScheduler {
      * @param availableGPUMap available GPU resource filled by this detection
      * @returns Available GPU number on this remote machine
      */
-    private gpuResourceDetection() : Map<RemoteMachineMeta, GPUInfo[]> {
+    private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
         const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
         this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => {
             // Assgin totoal GPU count as init available GPU number
             if (rmMeta.gpuSummary !== undefined) {
@@ -149,7 +148,6 @@ export class GPUScheduler {
             }
         }
         this.log.debug(`designated gpu indices: ${designatedGpuIndices}`);
-        // tslint:disable: strict-boolean-expressions
         rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => {
             // if the GPU has active process, OR be reserved by a job,
             // or index not in gpuIndices configuration in machineList,
@@ -175,7 +173,6 @@ export class GPUScheduler {
         return totalResourceMap;
     }
-    // tslint:enable: strict-boolean-expressions

     private selectMachine(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta {
         assert(rmMetas !== undefined && rmMetas.length > 0);
@@ -224,11 +221,11 @@ export class GPUScheduler {
             resultType: ScheduleResultType.SUCCEED,
             scheduleInfo: {
                 rmMeta: rmMeta,
-                cuda_visible_device: allocatedGPUs
+                cudaVisibleDevice: allocatedGPUs
                     .map((gpuInfo: GPUInfo) => {
                         return gpuInfo.index;
                     })
                     .join(',')
             }
         };
     }
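The renamed cudaVisibleDevice field is simply the allocated GPU indices joined with commas, ready to be exported as CUDA_VISIBLE_DEVICES. A standalone sketch, with GPUInfo reduced to the one field used here:

// Sketch only: build the cudaVisibleDevice string exactly as the hunk above
// does. GPUInfo is a simplified stand-in for the type from '../common/gpuData'.
interface GPUInfo { index: number; }

function toCudaVisibleDevice(allocatedGPUs: GPUInfo[]): string {
    return allocatedGPUs
        .map((gpuInfo: GPUInfo) => gpuInfo.index)
        .join(',');
}

// e.g. toCudaVisibleDevice([{ index: 0 }, { index: 2 }]) === '0,2'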
src/nni_manager/training_service/remote_machine/remoteMachineData.ts

@@ -13,13 +13,13 @@ import { GPUInfo, GPUSummary } from '../common/gpuData';
  * Metadata of remote machine for configuration and statuc query
  */
 export class RemoteMachineMeta {
     public readonly ip: string = '';
     public readonly port: number = 22;
     public readonly username: string = '';
     public readonly passwd: string = '';
     public readonly sshKeyPath?: string;
     public readonly passphrase?: string;
     public gpuSummary: GPUSummary | undefined;
     public readonly gpuIndices?: string;
     public readonly maxTrialNumPerGpu?: number;
     //TODO: initialize varialbe in constructor
@@ -43,11 +43,11 @@ export function parseGpuIndices(gpuIndices?: string): Set<number> | undefined {
  * The execution result for command executed on remote machine
  */
 export class RemoteCommandResult {
     public readonly stdout: string;
     public readonly stderr: string;
     public readonly exitCode: number;

     constructor(stdout: string, stderr: string, exitCode: number) {
         this.stdout = stdout;
         this.stderr = stderr;
         this.exitCode = exitCode;
@@ -186,7 +186,6 @@ export class SSHClientManager {
     /**
      * Create a new ssh connection client and initialize it
      */
-    // tslint:disable:non-literal-fs-path
     private initNewSSHClient(): Promise<Client> {
         const deferred: Deferred<Client> = new Deferred<Client>();
         const conn: Client = new Client();
@@ -225,9 +224,9 @@ export class SSHClientManager {
 }

 export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType };

-export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cuda_visible_device: string };
+export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string };

 export enum ScheduleResultType {
     // Schedule succeeded
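The hunk header above references parseGpuIndices(gpuIndices?: string): Set<number> | undefined, which pairs with the gpuIndices configuration field on RemoteMachineMeta. A hedged sketch inferred from that signature alone (the real implementation may validate its input differently):

// Sketch only: turn a comma-separated gpuIndices string like "0,1,3"
// into a Set of numbers, returning undefined when nothing is configured.
function parseGpuIndices(gpuIndices?: string): Set<number> | undefined {
    if (gpuIndices === undefined) {
        return undefined;
    }
    const indices: number[] = gpuIndices.split(',')
        .map((index: string) => parseInt(index.trim(), 10));
    return new Set(indices);
}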
src/nni_manager/training_service/remote_machine/remoteMachineJobRestServer.ts

@@ -15,7 +15,7 @@ import { RemoteMachineTrainingService } from './remoteMachineTrainingService';
 @component.Singleton
 export class RemoteMachineJobRestServer extends ClusterJobRestServer {
     @Inject
     private readonly remoteMachineTrainingService: RemoteMachineTrainingService;

     /**
      * constructor to provide NNIRestServer's own rest property, e.g. port
@@ -25,8 +25,7 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer {
         this.remoteMachineTrainingService = component.get(RemoteMachineTrainingService);
     }

-    // tslint:disable-next-line:no-any
-    protected handleTrialMetrics(jobId: string, metrics: any[]) : void {
+    protected handleTrialMetrics(jobId: string, metrics: any[]): void {
         // Split metrics array into single metric, then emit
         // Warning: If not split metrics into single ones, the behavior will be UNKNOWNls
         for (const singleMetric of metrics) {
src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
View file @
543239c6
...
@@ -4,12 +4,10 @@
...
@@ -4,12 +4,10 @@
'
use strict
'
;
'
use strict
'
;
import
*
as
assert
from
'
assert
'
;
import
*
as
assert
from
'
assert
'
;
import
*
as
cpp
from
'
child-process-promise
'
;
import
{
EventEmitter
}
from
'
events
'
;
import
{
EventEmitter
}
from
'
events
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
fs
from
'
fs
'
;
import
*
as
os
from
'
os
'
;
import
*
as
path
from
'
path
'
;
import
*
as
path
from
'
path
'
;
import
{
Client
,
ConnectConfig
}
from
'
ssh2
'
;
import
{
Client
}
from
'
ssh2
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
{
String
}
from
'
typescript-string-operations
'
;
import
{
String
}
from
'
typescript-string-operations
'
;
import
*
as
component
from
'
../../common/component
'
;
import
*
as
component
from
'
../../common/component
'
;
...
@@ -29,12 +27,12 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
...
@@ -29,12 +27,12 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import
{
GPUSummary
}
from
'
../common/gpuData
'
;
import
{
GPUSummary
}
from
'
../common/gpuData
'
;
import
{
TrialConfig
}
from
'
../common/trialConfig
'
;
import
{
TrialConfig
}
from
'
../common/trialConfig
'
;
import
{
TrialConfigMetadataKey
}
from
'
../common/trialConfigMetadataKey
'
;
import
{
TrialConfigMetadataKey
}
from
'
../common/trialConfigMetadataKey
'
;
import
{
execCopydir
,
execMkdir
,
execRemove
,
validateCodeDir
,
getGpuMetricsCollectorBashScriptContent
}
from
'
../common/util
'
;
import
{
execCopydir
,
execMkdir
,
validateCodeDir
,
getGpuMetricsCollectorBashScriptContent
}
from
'
../common/util
'
;
import
{
GPUScheduler
}
from
'
./gpuScheduler
'
;
import
{
GPUScheduler
}
from
'
./gpuScheduler
'
;
import
{
import
{
HOST_JOB_SHELL_FORMAT
,
RemoteCommandResult
,
REMOTEMACHINE_TRIAL_COMMAND_FORMAT
,
RemoteMachineMeta
,
RemoteCommandResult
,
REMOTEMACHINE_TRIAL_COMMAND_FORMAT
,
RemoteMachineMeta
,
RemoteMachineScheduleInfo
,
RemoteMachineScheduleResult
,
RemoteMachineTrialJobDetail
,
RemoteMachineScheduleInfo
,
RemoteMachineScheduleResult
,
RemoteMachineTrialJobDetail
,
ScheduleResultType
,
SSHClient
,
SSHClientManager
ScheduleResultType
,
SSHClientManager
}
from
'
./remoteMachineData
'
;
}
from
'
./remoteMachineData
'
;
import
{
RemoteMachineJobRestServer
}
from
'
./remoteMachineJobRestServer
'
;
import
{
RemoteMachineJobRestServer
}
from
'
./remoteMachineJobRestServer
'
;
import
{
SSHClientUtility
}
from
'
./sshClientUtility
'
;
import
{
SSHClientUtility
}
from
'
./sshClientUtility
'
;
...
@@ -93,7 +91,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -93,7 +91,7 @@ class RemoteMachineTrainingService implements TrainingService {
while
(
this
.
jobQueue
.
length
>
0
)
{
while
(
this
.
jobQueue
.
length
>
0
)
{
this
.
updateGpuReservation
();
this
.
updateGpuReservation
();
const
trialJobId
:
string
=
this
.
jobQueue
[
0
];
const
trialJobId
:
string
=
this
.
jobQueue
[
0
];
const
prepareResult
:
boolean
=
await
this
.
prepareTrialJob
(
trialJobId
);
const
prepareResult
:
boolean
=
await
this
.
prepareTrialJob
(
trialJobId
);
if
(
prepareResult
)
{
if
(
prepareResult
)
{
// Remove trial job with trialJobId from job queue
// Remove trial job with trialJobId from job queue
this
.
jobQueue
.
shift
();
this
.
jobQueue
.
shift
();
...
@@ -208,7 +206,6 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -208,7 +206,6 @@ class RemoteMachineTrainingService implements TrainingService {
* Submit trial job
* Submit trial job
* @param form trial job description form
* @param form trial job description form
*/
*/
// tslint:disable-next-line:informative-docs
public
async
submitTrialJob
(
form
:
TrialJobApplicationForm
):
Promise
<
TrialJobDetail
>
{
public
async
submitTrialJob
(
form
:
TrialJobApplicationForm
):
Promise
<
TrialJobDetail
>
{
if
(
this
.
trialConfig
===
undefined
)
{
if
(
this
.
trialConfig
===
undefined
)
{
throw
new
Error
(
'
trial config is not initialized
'
);
throw
new
Error
(
'
trial config is not initialized
'
);
...
@@ -241,12 +238,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -241,12 +238,7 @@ class RemoteMachineTrainingService implements TrainingService {
if
(
trialJobDetail
===
undefined
)
{
if
(
trialJobDetail
===
undefined
)
{
throw
new
Error
(
`updateTrialJob failed:
${
trialJobId
}
not found`
);
throw
new
Error
(
`updateTrialJob failed:
${
trialJobId
}
not found`
);
}
}
const
rmMeta
:
RemoteMachineMeta
|
undefined
=
(
<
RemoteMachineTrialJobDetail
>
trialJobDetail
).
rmMeta
;
await
this
.
writeParameterFile
(
trialJobId
,
form
.
hyperParameters
);
if
(
rmMeta
!==
undefined
)
{
await
this
.
writeParameterFile
(
trialJobId
,
form
.
hyperParameters
,
rmMeta
);
}
else
{
throw
new
Error
(
`updateTrialJob failed:
${
trialJobId
}
rmMeta not found`
);
}
return
trialJobDetail
;
return
trialJobDetail
;
}
}
...
@@ -262,7 +254,6 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -262,7 +254,6 @@ class RemoteMachineTrainingService implements TrainingService {
* Cancel trial job
* Cancel trial job
* @param trialJobId ID of trial job
* @param trialJobId ID of trial job
*/
*/
// tslint:disable:informative-docs no-unsafe-any
public
async
cancelTrialJob
(
trialJobId
:
string
,
isEarlyStopped
:
boolean
=
false
):
Promise
<
void
>
{
public
async
cancelTrialJob
(
trialJobId
:
string
,
isEarlyStopped
:
boolean
=
false
):
Promise
<
void
>
{
const
deferred
:
Deferred
<
void
>
=
new
Deferred
<
void
>
();
const
deferred
:
Deferred
<
void
>
=
new
Deferred
<
void
>
();
const
trialJob
:
RemoteMachineTrialJobDetail
|
undefined
=
this
.
trialJobsMap
.
get
(
trialJobId
);
const
trialJob
:
RemoteMachineTrialJobDetail
|
undefined
=
this
.
trialJobsMap
.
get
(
trialJobId
);
...
@@ -272,7 +263,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -272,7 +263,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
}
// Remove the job with trialJobId from job queue
// Remove the job with trialJobId from job queue
const
index
:
number
=
this
.
jobQueue
.
indexOf
(
trialJobId
);
const
index
:
number
=
this
.
jobQueue
.
indexOf
(
trialJobId
);
if
(
index
>=
0
)
{
if
(
index
>=
0
)
{
this
.
jobQueue
.
splice
(
index
,
1
);
this
.
jobQueue
.
splice
(
index
,
1
);
}
}
...
@@ -319,14 +310,13 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -319,14 +310,13 @@ class RemoteMachineTrainingService implements TrainingService {
await
this
.
setupConnections
(
value
);
await
this
.
setupConnections
(
value
);
this
.
gpuScheduler
=
new
GPUScheduler
(
this
.
machineSSHClientMap
);
this
.
gpuScheduler
=
new
GPUScheduler
(
this
.
machineSSHClientMap
);
break
;
break
;
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
{
const
remoteMachineTrailConfig
:
TrialConfig
=
<
TrialConfig
>
JSON
.
parse
(
value
);
const
remoteMachineTrailConfig
:
TrialConfig
=
<
TrialConfig
>
JSON
.
parse
(
value
);
// Parse trial config failed, throw Error
// Parse trial config failed, throw Error
if
(
remoteMachineTrailConfig
===
undefined
)
{
if
(
remoteMachineTrailConfig
===
undefined
)
{
throw
new
Error
(
'
trial config parsed failed
'
);
throw
new
Error
(
'
trial config parsed failed
'
);
}
}
// codeDir is not a valid directory, throw Error
// codeDir is not a valid directory, throw Error
// tslint:disable-next-line:non-literal-fs-path
if
(
!
fs
.
lstatSync
(
remoteMachineTrailConfig
.
codeDir
)
if
(
!
fs
.
lstatSync
(
remoteMachineTrailConfig
.
codeDir
)
.
isDirectory
())
{
.
isDirectory
())
{
throw
new
Error
(
`codeDir
${
remoteMachineTrailConfig
.
codeDir
}
is not a directory`
);
throw
new
Error
(
`codeDir
${
remoteMachineTrailConfig
.
codeDir
}
is not a directory`
);
...
@@ -343,6 +333,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -343,6 +333,7 @@ class RemoteMachineTrainingService implements TrainingService {
this
.
trialConfig
=
remoteMachineTrailConfig
;
this
.
trialConfig
=
remoteMachineTrailConfig
;
break
;
break
;
}
case
TrialConfigMetadataKey
.
MULTI_PHASE
:
case
TrialConfigMetadataKey
.
MULTI_PHASE
:
this
.
isMultiPhase
=
(
value
===
'
true
'
||
value
===
'
True
'
);
this
.
isMultiPhase
=
(
value
===
'
true
'
||
value
===
'
True
'
);
break
;
break
;
...
@@ -444,7 +435,6 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -444,7 +435,6 @@ class RemoteMachineTrainingService implements TrainingService {
await
SSHClientUtility
.
remoteExeCommand
(
`chmod 777
${
nniRootDir
}
${
nniRootDir
}
/*
${
nniRootDir
}
/scripts/*`
,
conn
);
await
SSHClientUtility
.
remoteExeCommand
(
`chmod 777
${
nniRootDir
}
${
nniRootDir
}
/*
${
nniRootDir
}
/scripts/*`
,
conn
);
//Begin to execute gpu_metrics_collection scripts
//Begin to execute gpu_metrics_collection scripts
// tslint:disable-next-line: no-floating-promises
const
script
=
getGpuMetricsCollectorBashScriptContent
(
remoteGpuScriptCollectorDir
);
const
script
=
getGpuMetricsCollectorBashScriptContent
(
remoteGpuScriptCollectorDir
);
SSHClientUtility
.
remoteExeCommand
(
`bash -c '
${
script
}
'`
,
conn
);
SSHClientUtility
.
remoteExeCommand
(
`bash -c '
${
script
}
'`
,
conn
);
...
@@ -464,7 +454,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -464,7 +454,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
}
private
async
prepareTrialJob
(
trialJobId
:
string
):
Promise
<
boolean
>
{
private
async
prepareTrialJob
(
trialJobId
:
string
):
Promise
<
boolean
>
{
const
deferred
:
Deferred
<
boolean
>
=
new
Deferred
<
boolean
>
();
const
deferred
:
Deferred
<
boolean
>
=
new
Deferred
<
boolean
>
();
if
(
this
.
trialConfig
===
undefined
)
{
if
(
this
.
trialConfig
===
undefined
)
{
throw
new
Error
(
'
trial config is not initialized
'
);
throw
new
Error
(
'
trial config is not initialized
'
);
...
@@ -485,13 +475,13 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -485,13 +475,13 @@ class RemoteMachineTrainingService implements TrainingService {
// get an ssh client from scheduler
// get an ssh client from scheduler
const
rmScheduleResult
:
RemoteMachineScheduleResult
=
this
.
gpuScheduler
.
scheduleMachine
(
this
.
trialConfig
.
gpuNum
,
trialJobDetail
);
const
rmScheduleResult
:
RemoteMachineScheduleResult
=
this
.
gpuScheduler
.
scheduleMachine
(
this
.
trialConfig
.
gpuNum
,
trialJobDetail
);
if
(
rmScheduleResult
.
resultType
===
ScheduleResultType
.
REQUIRE_EXCEED_TOTAL
)
{
if
(
rmScheduleResult
.
resultType
===
ScheduleResultType
.
REQUIRE_EXCEED_TOTAL
)
{
const
errorMessage
:
string
=
`Required GPU number
${
this
.
trialConfig
.
gpuNum
}
is too large, no machine can meet`
;
const
errorMessage
:
string
=
`Required GPU number
${
this
.
trialConfig
.
gpuNum
}
is too large, no machine can meet`
;
this
.
log
.
error
(
errorMessage
);
this
.
log
.
error
(
errorMessage
);
deferred
.
reject
();
deferred
.
reject
();
throw
new
NNIError
(
NNIErrorNames
.
RESOURCE_NOT_AVAILABLE
,
errorMessage
);
throw
new
NNIError
(
NNIErrorNames
.
RESOURCE_NOT_AVAILABLE
,
errorMessage
);
}
else
if
(
rmScheduleResult
.
resultType
===
ScheduleResultType
.
SUCCEED
}
else
if
(
rmScheduleResult
.
resultType
===
ScheduleResultType
.
SUCCEED
&&
rmScheduleResult
.
scheduleInfo
!==
undefined
)
{
&&
rmScheduleResult
.
scheduleInfo
!==
undefined
)
{
const
rmScheduleInfo
:
RemoteMachineScheduleInfo
=
rmScheduleResult
.
scheduleInfo
;
const
rmScheduleInfo
:
RemoteMachineScheduleInfo
=
rmScheduleResult
.
scheduleInfo
;
const
trialWorkingFolder
:
string
=
unixPathJoin
(
this
.
remoteExpRootDir
,
'
trials
'
,
trialJobId
);
const
trialWorkingFolder
:
string
=
unixPathJoin
(
this
.
remoteExpRootDir
,
'
trials
'
,
trialJobId
);
trialJobDetail
.
rmMeta
=
rmScheduleInfo
.
rmMeta
;
trialJobDetail
.
rmMeta
=
rmScheduleInfo
.
rmMeta
;
...
@@ -521,7 +511,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -521,7 +511,7 @@ class RemoteMachineTrainingService implements TrainingService {
if
(
this
.
trialConfig
===
undefined
)
{
if
(
this
.
trialConfig
===
undefined
)
{
throw
new
Error
(
'
trial config is not initialized
'
);
throw
new
Error
(
'
trial config is not initialized
'
);
}
}
const
cuda
_v
isible
_d
evice
:
string
=
rmScheduleInfo
.
cuda
_v
isible
_d
evice
;
const
cuda
V
isible
D
evice
:
string
=
rmScheduleInfo
.
cuda
V
isible
D
evice
;
const
sshClient
:
Client
|
undefined
=
this
.
trialSSHClientMap
.
get
(
trialJobId
);
const
sshClient
:
Client
|
undefined
=
this
.
trialSSHClientMap
.
get
(
trialJobId
);
if
(
sshClient
===
undefined
)
{
if
(
sshClient
===
undefined
)
{
assert
(
false
,
'
sshClient is undefined.
'
);
assert
(
false
,
'
sshClient is undefined.
'
);
...
@@ -543,19 +533,18 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -543,19 +533,18 @@ class RemoteMachineTrainingService implements TrainingService {
// See definition in remoteMachineData.ts
// See definition in remoteMachineData.ts
let
command
:
string
;
let
command
:
string
;
// Set CUDA_VISIBLE_DEVICES environment variable based on cuda
_v
isible
_d
evice
// Set CUDA_VISIBLE_DEVICES environment variable based on cuda
V
isible
D
evice
// If no valid cuda
_v
isible
_d
evice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
// If no valid cuda
V
isible
D
evice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
// If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script
// If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script
if
(
this
.
trialConfig
.
gpuNum
===
undefined
)
{
if
(
this
.
trialConfig
.
gpuNum
===
undefined
)
{
command
=
this
.
trialConfig
.
command
;
command
=
this
.
trialConfig
.
command
;
}
else
{
}
else
{
if
(
typeof
cuda
_v
isible
_d
evice
===
'
string
'
&&
cuda
_v
isible
_d
evice
.
length
>
0
)
{
if
(
typeof
cuda
V
isible
D
evice
===
'
string
'
&&
cuda
V
isible
D
evice
.
length
>
0
)
{
command
=
`CUDA_VISIBLE_DEVICES=
${
cuda
_v
isible
_d
evice
}
${
this
.
trialConfig
.
command
}
`
;
command
=
`CUDA_VISIBLE_DEVICES=
${
cuda
V
isible
D
evice
}
${
this
.
trialConfig
.
command
}
`
;
}
else
{
}
else
{
command
=
`CUDA_VISIBLE_DEVICES=" "
${
this
.
trialConfig
.
command
}
`
;
command
=
`CUDA_VISIBLE_DEVICES=" "
${
this
.
trialConfig
.
command
}
`
;
}
}
}
}
// tslint:disable-next-line: strict-boolean-expressions
const
nniManagerIp
:
string
=
this
.
nniManagerIpConfig
?
this
.
nniManagerIpConfig
.
nniManagerIp
:
getIPV4Address
();
const
nniManagerIp
:
string
=
this
.
nniManagerIpConfig
?
this
.
nniManagerIpConfig
.
nniManagerIp
:
getIPV4Address
();
if
(
this
.
remoteRestServerPort
===
undefined
)
{
if
(
this
.
remoteRestServerPort
===
undefined
)
{
const
restServer
:
RemoteMachineJobRestServer
=
component
.
get
(
RemoteMachineJobRestServer
);
const
restServer
:
RemoteMachineJobRestServer
=
component
.
get
(
RemoteMachineJobRestServer
);
...
@@ -584,16 +573,15 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -584,16 +573,15 @@ class RemoteMachineTrainingService implements TrainingService {
//create tmp trial working folder locally.
//create tmp trial working folder locally.
await
execCopydir
(
this
.
trialConfig
.
codeDir
,
trialLocalTempFolder
);
await
execCopydir
(
this
.
trialConfig
.
codeDir
,
trialLocalTempFolder
);
const
installScriptContent
:
string
=
CONTAINER_INSTALL_NNI_SHELL_FORMAT
;
const
installScriptContent
:
string
=
CONTAINER_INSTALL_NNI_SHELL_FORMAT
;
// Write NNI installation file to local tmp files
// Write NNI installation file to local tmp files
await
fs
.
promises
.
writeFile
(
path
.
join
(
trialLocalTempFolder
,
'
install_nni.sh
'
),
installScriptContent
,
{
encoding
:
'
utf8
'
});
await
fs
.
promises
.
writeFile
(
path
.
join
(
trialLocalTempFolder
,
'
install_nni.sh
'
),
installScriptContent
,
{
encoding
:
'
utf8
'
});
// Write file content ( run.sh and parameter.cfg ) to local tmp files
// Write file content ( run.sh and parameter.cfg ) to local tmp files
await
fs
.
promises
.
writeFile
(
path
.
join
(
trialLocalTempFolder
,
'
run.sh
'
),
runScriptTrialContent
,
{
encoding
:
'
utf8
'
});
await
fs
.
promises
.
writeFile
(
path
.
join
(
trialLocalTempFolder
,
'
run.sh
'
),
runScriptTrialContent
,
{
encoding
:
'
utf8
'
});
await
this
.
writeParameterFile
(
trialJobId
,
form
.
hyperParameters
,
rmScheduleInfo
.
rmMeta
);
await
this
.
writeParameterFile
(
trialJobId
,
form
.
hyperParameters
);
// Copy files in codeDir to remote working directory
// Copy files in codeDir to remote working directory
await
SSHClientUtility
.
copyDirectoryToRemote
(
trialLocalTempFolder
,
trialWorkingFolder
,
sshClient
,
this
.
remoteOS
);
await
SSHClientUtility
.
copyDirectoryToRemote
(
trialLocalTempFolder
,
trialWorkingFolder
,
sshClient
,
this
.
remoteOS
);
// Execute command in remote machine
// Execute command in remote machine
// tslint:disable-next-line: no-floating-promises
SSHClientUtility
.
remoteExeCommand
(
`bash
${
unixPathJoin
(
trialWorkingFolder
,
'
run.sh
'
)}
`
,
sshClient
);
SSHClientUtility
.
remoteExeCommand
(
`bash
${
unixPathJoin
(
trialWorkingFolder
,
'
run.sh
'
)}
`
,
sshClient
);
}
}
...
@@ -610,6 +598,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -610,6 +598,7 @@ class RemoteMachineTrainingService implements TrainingService {
const
deferred
:
Deferred
<
TrialJobDetail
>
=
new
Deferred
<
TrialJobDetail
>
();
const
deferred
:
Deferred
<
TrialJobDetail
>
=
new
Deferred
<
TrialJobDetail
>
();
const
jobpidPath
:
string
=
this
.
getJobPidPath
(
trialJob
.
id
);
const
jobpidPath
:
string
=
this
.
getJobPidPath
(
trialJob
.
id
);
const
trialReturnCodeFilePath
:
string
=
unixPathJoin
(
this
.
remoteExpRootDir
,
'
trials
'
,
trialJob
.
id
,
'
.nni
'
,
'
code
'
);
const
trialReturnCodeFilePath
:
string
=
unixPathJoin
(
this
.
remoteExpRootDir
,
'
trials
'
,
trialJob
.
id
,
'
.nni
'
,
'
code
'
);
/* eslint-disable require-atomic-updates */
try
{
try
{
const
killResult
:
number
=
(
await
SSHClientUtility
.
remoteExeCommand
(
`kill -0
\`
cat
${
jobpidPath
}
\`
`
,
sshClient
)).
exitCode
;
const
killResult
:
number
=
(
await
SSHClientUtility
.
remoteExeCommand
(
`kill -0
\`
cat
${
jobpidPath
}
\`
`
,
sshClient
)).
exitCode
;
// if the process of jobpid is not alive any more
// if the process of jobpid is not alive any more
@@ -646,7 +635,7 @@ class RemoteMachineTrainingService implements TrainingService {
                 deferred.resolve(trialJob);
             }
         }
+        /* eslint-enable require-atomic-updates */
         return deferred.promise;
     }
@@ -662,7 +651,7 @@ class RemoteMachineTrainingService implements TrainingService {
         return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId());
     }

-    public get MetricsEmitter() : EventEmitter {
+    public get MetricsEmitter(): EventEmitter {
         return this.metricsEmitter;
     }
@@ -672,13 +661,10 @@ class RemoteMachineTrainingService implements TrainingService {
             throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${jobId}`);
         }
-        let jobpidPath: string;
-        jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
-
-        return jobpidPath;
+        return unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
     }

-    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
+    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
         const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
         if (sshClient === undefined) {
             throw new Error('sshClient is undefined.');
src/nni_manager/training_service/remote_machine/sshClientUtility.ts
View file @ 543239c6

@@ -4,7 +4,6 @@
 'use strict';

 import * as assert from 'assert';
-import * as cpp from 'child-process-promise';
 import * as os from 'os';
 import * as path from 'path';
 import { Client, ClientChannel, SFTPWrapper } from 'ssh2';
@@ -22,44 +21,18 @@ import { RemoteCommandResult } from './remoteMachineData';
  */
 export namespace SSHClientUtility {
-    /**
-     * Copy files and directories in local directory recursively to remote directory
-     * @param localDirectory local diretory
-     * @param remoteDirectory remote directory
-     * @param sshClient SSH client
-     */
-    export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string) : Promise<void> {
-        const deferred: Deferred<void> = new Deferred<void>();
-        const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
-        const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
-        const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
-        // Compress files in local directory to experiment root directory
-        await tarAdd(localTarPath, localDirectory);
-        // Copy the compressed file to remoteDirectory and delete it
-        await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
-        await execRemove(localTarPath);
-        // Decompress the remote compressed file in and delete it
-        await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
-        await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
-        deferred.resolve();
-
-        return deferred.promise;
-    }
-
     /**
      * Copy local file to remote path
      * @param localFilePath the path of local file
      * @param remoteFilePath the target path in remote machine
      * @param sshClient SSH Client
      */
-    export function copyFileToRemote(localFilePath: string, remoteFilePath: string, sshClient: Client) : Promise<boolean> {
+    export function copyFileToRemote(localFilePath: string, remoteFilePath: string, sshClient: Client): Promise<boolean> {
         const log: Logger = getLogger();
         log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
         assert(sshClient !== undefined);
         const deferred: Deferred<boolean> = new Deferred<boolean>();
         sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
             if (err !== undefined && err !== null) {
                 log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
                 deferred.reject(err);
@@ -67,7 +40,7 @@ export namespace SSHClientUtility {
                 return;
             }
             assert(sftp !== undefined);
             sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => {
                 sftp.end();
                 if (fastPutErr !== undefined && fastPutErr !== null) {
                     deferred.reject(fastPutErr);
@@ -85,16 +58,15 @@ export namespace SSHClientUtility {
      * @param command the command to execute remotely
      * @param client SSH Client
      */
-    // tslint:disable:no-unsafe-any no-any
     export function remoteExeCommand(command: string, client: Client): Promise<RemoteCommandResult> {
         const log: Logger = getLogger();
         log.debug(`remoteExeCommand: command: [${command}]`);
         const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
         let stdout: string = '';
         let stderr: string = '';
         let exitCode: number;

         client.exec(command, (err: Error, channel: ClientChannel) => {
             if (err !== undefined && err !== null) {
                 log.error(`remoteExeCommand: ${err.message}`);
                 deferred.reject(err);
@@ -102,14 +74,14 @@ export namespace SSHClientUtility {
                 return;
             }
             channel.on('data', (data: any, dataStderr: any) => {
                 if (dataStderr !== undefined && dataStderr !== null) {
                     stderr += data.toString();
                 } else {
                     stdout += data.toString();
                 }
             })
             .on('exit', (code: any, signal: any) => {
                 exitCode = <number>code;
                 deferred.resolve({
                     stdout: stdout,
@@ -122,9 +94,34 @@ export namespace SSHClientUtility {
         return deferred.promise;
     }

+    /**
+     * Copy files and directories in local directory recursively to remote directory
+     * @param localDirectory local diretory
+     * @param remoteDirectory remote directory
+     * @param sshClient SSH client
+     */
+    export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string): Promise<void> {
+        const deferred: Deferred<void> = new Deferred<void>();
+        const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
+        const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
+        const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
+        // Compress files in local directory to experiment root directory
+        await tarAdd(localTarPath, localDirectory);
+        // Copy the compressed file to remoteDirectory and delete it
+        await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
+        await execRemove(localTarPath);
+        // Decompress the remote compressed file in and delete it
+        await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
+        await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
+        deferred.resolve();
+
+        return deferred.promise;
+    }
+
     export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> {
         const deferred: Deferred<string> = new Deferred<string>();
         sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
             if (err !== undefined && err !== null) {
                 getLogger()
                     .error(`getRemoteFileContent: ${err.message}`);

@@ -133,10 +130,10 @@ export namespace SSHClientUtility {
                 return;
             }
             try {
                 const sftpStream: stream.Readable = sftp.createReadStream(filePath);
                 let dataBuffer: string = '';
                 sftpStream.on('data', (data: Buffer | string) => {
                     dataBuffer += data;
                 })
                 .on('error', (streamErr: Error) => {

@@ -158,5 +155,4 @@ export namespace SSHClientUtility {
         return deferred.promise;
     }
-    // tslint:enable:no-unsafe-any no-any
 }
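`copyDirectoryToRemote` (relocated to the bottom of the namespace by this commit) is a tar-copy-extract pipeline: compress the local directory, SFTP the single archive, then untar and delete it remotely. A rough Python sketch of the same flow, assuming the third-party `paramiko` library stands in for `ssh2` and the client is already connected (names here are illustrative, not part of the commit):

```python
import os
import tarfile
import tempfile
import uuid

import paramiko  # assumption: paramiko plays the role of the ssh2 Client

def copy_directory_to_remote(client: paramiko.SSHClient, local_dir: str, remote_dir: str) -> None:
    tmp_tar = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex[:10]}.tar.gz")
    remote_tar = f"/tmp/{os.path.basename(tmp_tar)}"
    # Compress the local directory into a single archive
    with tarfile.open(tmp_tar, "w:gz") as tar:
        tar.add(local_dir, arcname=".")
    # Copy the compressed file to the remote machine and delete the local copy
    sftp = client.open_sftp()
    sftp.put(tmp_tar, remote_tar)
    sftp.close()
    os.remove(tmp_tar)
    # Decompress on the remote side, then delete the remote archive
    for cmd in (f"tar -oxzf {remote_tar} -C {remote_dir}", f"rm {remote_tar}"):
        _, stdout, _ = client.exec_command(cmd)
        stdout.channel.recv_exit_status()  # block until the command finishes
```

Shipping one archive instead of many small files keeps the number of SFTP round trips constant regardless of how many files the trial's code directory contains.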
src/nni_manager/yarn.lock
View file @ 543239c6

@@ -703,7 +703,7 @@ buffer-stream-reader@^0.1.1:
   version "0.1.1"
   resolved "https://registry.yarnpkg.com/buffer-stream-reader/-/buffer-stream-reader-0.1.1.tgz#ca8bf93631deedd8b8f8c3bb44991cc30951e259"

-builtin-modules@^1.0.0, builtin-modules@^1.1.1:
+builtin-modules@^1.0.0:
   version "1.1.1"
   resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-1.1.1.tgz#270f076c5a72c02f5b65a47df94c5fe3a278892f"

@@ -841,7 +841,7 @@ chalk@^1.0.0:
     strip-ansi "^3.0.0"
     supports-color "^2.0.0"

-chalk@^2.0.0, chalk@^2.3.0:
+chalk@^2.0.0:
   version "2.4.1"
   resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.1.tgz#18c49ab16a037b6eb0152cc83e3471338215b66e"
   dependencies:

@@ -971,10 +971,6 @@ commander@2.15.1:
   version "2.15.1"
   resolved "https://registry.yarnpkg.com/commander/-/commander-2.15.1.tgz#df46e867d0fc2aec66a34662b406a9ccafff5b0f"

-commander@^2.12.1:
-  version "2.16.0"
-  resolved "https://registry.yarnpkg.com/commander/-/commander-2.16.0.tgz#f16390593996ceb4f3eeb020b31d78528f7f8a50"
-
 commander@~2.17.1:
   version "2.17.1"
   resolved "https://registry.yarnpkg.com/commander/-/commander-2.17.1.tgz#bd77ab7de6de94205ceacc72f1716d29f20a77bf"

@@ -1134,7 +1130,7 @@ debug@^4.0.1, debug@^4.1.0, debug@^4.1.1:
   dependencies:
     ms "^2.1.1"

-debuglog@*, debuglog@^1.0.1:
+debuglog@^1.0.1:
   version "1.0.1"
   resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"

@@ -1217,7 +1213,7 @@ dezalgo@^1.0.0, dezalgo@~1.0.3:
     asap "^2.0.0"
     wrappy "1"

-diff@3.5.0, diff@^3.1.0, diff@^3.2.0:
+diff@3.5.0, diff@^3.1.0:
   version "3.5.0"
   resolved "https://registry.yarnpkg.com/diff/-/diff-3.5.0.tgz#800c0dd1e0a8bfbc95835c202ad220fe317e5a12"

@@ -2080,7 +2076,7 @@ import-lazy@^2.1.0:
   version "2.1.0"
   resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"

-imurmurhash@*, imurmurhash@^0.1.4:
+imurmurhash@^0.1.4:
   version "0.1.4"
   resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"

@@ -2519,10 +2515,6 @@ lockfile@~1.0.3:
   dependencies:
     signal-exit "^3.0.2"

-lodash._baseindexof@*:
-  version "3.1.0"
-  resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
-
 lodash._baseuniq@~4.6.0:
   version "4.6.0"
   resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"

@@ -2530,28 +2522,10 @@ lodash._baseuniq@~4.6.0:
     lodash._createset "~4.0.0"
     lodash._root "~3.0.0"

-lodash._bindcallback@*:
-  version "3.0.1"
-  resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
-
-lodash._cacheindexof@*:
-  version "3.0.2"
-  resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
-
-lodash._createcache@*:
-  version "3.1.2"
-  resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
-  dependencies:
-    lodash._getnative "^3.0.0"
-
 lodash._createset@~4.0.0:
   version "4.0.3"
   resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"

-lodash._getnative@*, lodash._getnative@^3.0.0:
-  version "3.9.1"
-  resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
-
 lodash._root@~3.0.0:
   version "3.0.1"
   resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"

@@ -2600,10 +2574,6 @@ lodash.pick@^4.4.0:
   version "4.4.0"
   resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"

-lodash.restparam@*:
-  version "3.6.1"
-  resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
-
 lodash.unescape@4.0.1:
   version "4.0.1"
   resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"

@@ -3519,10 +3489,6 @@ path-key@^2.0.0, path-key@^2.0.1:
   version "2.0.1"
   resolved "https://registry.yarnpkg.com/path-key/-/path-key-2.0.1.tgz#411cadb574c5a140d3a4b1910d40d80cc9f40b40"

-path-parse@^1.0.5:
-  version "1.0.5"
-  resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.5.tgz#3c1adf871ea9cd6c9431b6ea2bd74a0ff055c4c1"
-
 path-parse@^1.0.6:
   version "1.0.6"
   resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.6.tgz#d62dbb5679405d72c4737ec58600e9ddcf06d24c"

@@ -3834,7 +3800,7 @@ readable-stream@~2.0.0:
     string_decoder "~0.10.x"
     util-deprecate "~1.0.1"

-readdir-scoped-modules@*, readdir-scoped-modules@^1.0.0:
+readdir-scoped-modules@^1.0.0:
   version "1.1.0"
   resolved "https://registry.yarnpkg.com/readdir-scoped-modules/-/readdir-scoped-modules-1.1.0.tgz#8d45407b4f870a0dcaebc0e28670d18e74514309"
   dependencies:

@@ -3977,12 +3943,6 @@ resolve@^1.10.0:
   dependencies:
     path-parse "^1.0.6"

-resolve@^1.3.2:
-  version "1.8.1"
-  resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.8.1.tgz#82f1ec19a423ac1fbd080b0bab06ba36e84a7a26"
-  dependencies:
-    path-parse "^1.0.5"
-
 responselike@1.0.2:
   version "1.0.2"
   resolved "https://registry.yarnpkg.com/responselike/-/responselike-1.0.2.tgz#918720ef3b631c5642be068f15ade5a46f4ba1e7"

@@ -4599,7 +4559,7 @@ ts-node@^7.0.0:
     source-map-support "^0.5.6"
     yn "^2.0.0"

-tslib@^1.8.0, tslib@^1.8.1:
+tslib@^1.8.1:
   version "1.9.3"
   resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286"

@@ -4607,42 +4567,6 @@ tslib@^1.9.0:
   version "1.10.0"
   resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.10.0.tgz#c3c19f95973fb0a62973fb09d90d961ee43e5c8a"

-tslint-microsoft-contrib@^6.0.0:
-  version "6.2.0"
-  resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-6.2.0.tgz#8aa0f40584d066d05e6a5e7988da5163b85f2ad4"
-  dependencies:
-    tsutils "^2.27.2 <2.29.0"
-
-tslint@^5.12.0:
-  version "5.18.0"
-  resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.18.0.tgz#f61a6ddcf372344ac5e41708095bbf043a147ac6"
-  dependencies:
-    "@babel/code-frame" "^7.0.0"
-    builtin-modules "^1.1.1"
-    chalk "^2.3.0"
-    commander "^2.12.1"
-    diff "^3.2.0"
-    glob "^7.1.1"
-    js-yaml "^3.13.1"
-    minimatch "^3.0.4"
-    mkdirp "^0.5.1"
-    resolve "^1.3.2"
-    semver "^5.3.0"
-    tslib "^1.8.0"
-    tsutils "^2.29.0"
-
-"tsutils@^2.27.2 <2.29.0":
-  version "2.28.0"
-  resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.28.0.tgz#6bd71e160828f9d019b6f4e844742228f85169a1"
-  dependencies:
-    tslib "^1.8.1"
-
-tsutils@^2.29.0:
-  version "2.29.0"
-  resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.29.0.tgz#32b488501467acbedd4b85498673a0812aca0b99"
-  dependencies:
-    tslib "^1.8.1"
-
 tsutils@^3.17.1:
   version "3.17.1"
   resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.17.1.tgz#ed719917f11ca0dee586272b2ac49e015a2dd759"

@@ -4818,7 +4742,7 @@ v8-compile-cache@^2.0.3:
   version "2.1.0"
   resolved "https://registry.yarnpkg.com/v8-compile-cache/-/v8-compile-cache-2.1.0.tgz#e14de37b31a6d194f5690d67efc4e7f6fc6ab30e"

-validate-npm-package-license@*, validate-npm-package-license@^3.0.1:
+validate-npm-package-license@^3.0.1:
   version "3.0.4"
   resolved "https://registry.yarnpkg.com/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz#fc91f6b9c7ba15c857f4cb2c5defeec39d4f410a"
   dependencies:
src/sdk/pynni/nni/batch_tuner/batch_tuner.py
View file @ 543239c6

@@ -24,13 +24,15 @@ class BatchTuner(Tuner):
     Examples
     --------
     The search space only be accepted like:
-    ```
-    {
-        'combine_params': { '_type': 'choice',
-                            '_value': '[{...}, {...}, {...}]',
-        }
-    }
-    ```
+    ::
+
+        {'combine_params':
+            { '_type': 'choice',
+              '_value': '[{...}, {...}, {...}]',
+            }
+        }
     """

     def __init__(self):
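For reference, a concrete search space satisfying this format (the parameter dicts are illustrative; per NNI's search-space JSON, `_value` holds the list of combinations the tuner hands out one by one):

```python
search_space = {
    'combine_params': {
        '_type': 'choice',
        '_value': [
            {'optimizer': 'Adam', 'learning_rate': 0.001},
            {'optimizer': 'Adam', 'learning_rate': 0.0001},
            {'optimizer': 'SGD', 'learning_rate': 0.01},
        ]
    }
}
```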
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
View file @ 543239c6

@@ -5,7 +5,7 @@ import logging
 import torch
 from .compressor import Pruner

-__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner', 'L1FilterPruner', 'SlimPruner']
+__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']

 logger = logging.getLogger('torch pruner')
@@ -166,119 +166,132 @@ class AGP_Pruner(Pruner):
                 self.if_init_list[k] = True


-class FPGMPruner(Pruner):
+class SlimPruner(Pruner):
     """
-    A filter pruner via geometric median.
-    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
-    https://arxiv.org/pdf/1811.00250.pdf
+    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
+    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
+    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
+    https://arxiv.org/pdf/1708.06519.pdf
     """

     def __init__(self, model, config_list):
         """
         Parameters
         ----------
-        model : pytorch model
-            the model user wants to compress
-        config_list: list
+        config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
         """
         super().__init__(model, config_list)
-        self.mask_dict = {}
-        self.epoch_pruned_layers = set()
+        self.mask_calculated_ops = set()
+        weight_list = []
+        if len(config_list) > 1:
+            logger.warning('Slim pruner only supports 1 configuration')
+        config = config_list[0]
+        for (layer, config) in self.detect_modules_to_compress():
+            assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
+            weight_list.append(layer.module.weight.data.abs().clone())
+        all_bn_weights = torch.cat(weight_list)
+        k = int(all_bn_weights.shape[0] * config['sparsity'])
+        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

     def calc_mask(self, layer, config):
         """
-        Supports Conv1d, Conv2d
-        filter dimensions for Conv1d:
-            OUT: number of output channel
-            IN: number of input channel
-            LEN: filter length
-        filter dimensions for Conv2d:
-            OUT: number of output channel
-            IN: number of input channel
-            H: filter height
-            W: filter width
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.
         Parameters
         ----------
         layer : LayerInfo
-            calculate mask for `layer`'s weight
+            the layer to instrument the compression operation
         config : dict
-            the configuration for generating the mask
+            layer's pruning config
+        Returns
+        -------
+        torch.Tensor
+            mask of the layer's weight
         """
         weight = layer.module.weight.data
-        assert 0 <= config.get('sparsity') < 1
-        assert layer.type in ['Conv1d', 'Conv2d']
-        assert layer.type in config['op_types']
-
-        if layer.name in self.epoch_pruned_layers:
-            assert layer.name in self.mask_dict
-            return self.mask_dict.get(layer.name)
-
-        masks = torch.ones(weight.size()).type_as(weight)
-
-        try:
-            num_filters = weight.size(0)
-            num_prune = int(num_filters * config.get('sparsity'))
-            if num_filters < 2 or num_prune < 1:
-                return masks
-            min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
-            for idx in min_gm_idx:
-                masks[idx] = 0.
-        finally:
-            self.mask_dict.update({layer.name: masks})
-            self.epoch_pruned_layers.add(layer.name)
-
-        return masks
-
-    def _get_min_gm_kernel_idx(self, weight, n):
-        assert len(weight.size()) in [3, 4]
-        dist_list = []
-        for out_i in range(weight.size(0)):
-            dist_sum = self._get_distance_sum(weight, out_i)
-            dist_list.append((dist_sum, out_i))
-        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
-
-        return [x[1] for x in min_gm_kernels]
-
-    def _get_distance_sum(self, weight, out_idx):
-        """
-        Calculate the total distance between a specified filter (by out_idex and in_idx) and
-        all other filters.
-        Optimized verision of following naive implementation:
-        def _get_distance_sum(self, weight, in_idx, out_idx):
-            w = weight.view(-1, weight.size(-2), weight.size(-1))
-            dist_sum = 0.
-            for k in w:
-                dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
-            return dist_sum
-        Parameters
-        ----------
-        weight: Tensor
-            convolutional filter weight
-        out_idx: int
-            output channel index of specified filter, this method calculates the total distance
-            between this specified filter and all other filters.
-        Returns
-        -------
-        float32
-            The total distance
-        """
-        logger.debug('weight size: %s', weight.size())
-        assert len(weight.size()) in [3, 4], 'unsupported weight shape'
-
-        w = weight.view(weight.size(0), -1)
-        anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
-        x = w - anchor_w
-        x = (x * x).sum(-1)
-        x = torch.sqrt(x)
-
-        return x.sum()
-
-    def update_epoch(self, epoch):
-        self.epoch_pruned_layers = set()
+        op_name = layer.name
+        op_type = layer.type
+        assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
+        if op_name in self.mask_calculated_ops:
+            assert op_name in self.mask_dict
+            return self.mask_dict.get(op_name)
+        mask = torch.ones(weight.size()).type_as(weight)
+        try:
+            w_abs = weight.abs()
+            mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
+        finally:
+            self.mask_dict.update({layer.name: mask})
+            self.mask_calculated_ops.add(layer.name)
+
+        return mask
+
+
+class RankFilterPruner(Pruner):
+    """
+    A structured pruning base class that prunes the filters with the smallest
+    importance criterion in convolution layers to achieve a preset level of network sparsity.
+    """
+
+    def __init__(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : torch.nn.module
+            Model to be pruned
+        config_list : list
+            support key for each list item:
+                - sparsity: percentage of convolutional filters to be pruned.
+        """
+        super().__init__(model, config_list)
+        self.mask_calculated_ops = set()
+
+    def _get_mask(self, base_mask, weight, num_prune):
+        return torch.ones(weight.size()).type_as(weight)
+
+    def calc_mask(self, layer, config):
+        """
+        Calculate the mask of given layer.
+        Filters with the smallest importance criterion of the kernel weights are masked.
+        Parameters
+        ----------
+        layer : LayerInfo
+            the layer to instrument the compression operation
+        config : dict
+            layer's pruning config
+        Returns
+        -------
+        torch.Tensor
+            mask of the layer's weight
+        """
+        weight = layer.module.weight.data
+        op_name = layer.name
+        op_type = layer.type
+        assert 0 <= config.get('sparsity') < 1
+        assert op_type in ['Conv1d', 'Conv2d']
+        assert op_type in config.get('op_types')
+        if op_name in self.mask_calculated_ops:
+            assert op_name in self.mask_dict
+            return self.mask_dict.get(op_name)
+        mask = torch.ones(weight.size()).type_as(weight)
+        try:
+            filters = weight.size(0)
+            num_prune = int(filters * config.get('sparsity'))
+            if filters < 2 or num_prune < 1:
+                return mask
+            mask = self._get_mask(mask, weight, num_prune)
+        finally:
+            self.mask_dict.update({op_name: mask})
+            self.mask_calculated_ops.add(op_name)
+
+        return mask.detach()


-class L1FilterPruner(Pruner):
+class L1FilterPruner(RankFilterPruner):
     """
     A structured pruning algorithm that prunes the filters of smallest magnitude
     weights sum in the convolution layers to achieve a preset level of network sparsity.
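Note on the hunk above: the key move in the new `SlimPruner.__init__` is computing one global threshold across every BN layer's scale factors, so channels are ranked network-wide rather than per layer. A toy sketch of that ranking, using standalone tensors in place of real BatchNorm2d layers (values chosen for illustration):

```python
import torch

# Stand-ins for the absolute BN scale factors gathered from each BatchNorm2d layer
weight_list = [torch.tensor([0.9, 0.01, 0.5]), torch.tensor([0.3, 0.02])]
sparsity = 0.4

all_bn_weights = torch.cat(weight_list)              # 5 scale factors in total
k = int(all_bn_weights.shape[0] * sparsity)          # prune the 2 smallest -> k = 2
global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

# Per-layer masks: scale factors strictly above the threshold survive
masks = [torch.gt(w.abs(), global_threshold).float() for w in weight_list]
print(masks)  # [tensor([1., 0., 1.]), tensor([1., 0.])]
```

The new `RankFilterPruner` base class then applies the template-method pattern: the shared `calc_mask` handles config validation, the degenerate cases and mask caching, while each subclass only supplies its ranking criterion in `_get_mask`.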
@@ -299,107 +312,162 @@ class L1FilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()

-    def calc_mask(self, layer, config):
+    def _get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
         Filters with the smallest sum of its absolute kernel weights are masked.
         Parameters
         ----------
-        layer : LayerInfo
-            the layer to instrument the compression operation
-        config : dict
-            layer's pruning config
+        base_mask : torch.Tensor
+            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        weight : torch.Tensor
+            Layer's weight
+        num_prune : int
+            Num of filters to prune
         Returns
         -------
         torch.Tensor
-            mask of the layer's weight
+            Mask of the layer's weight
         """
-        weight = layer.module.weight.data
-        op_name = layer.name
-        op_type = layer.type
-        assert op_type == 'Conv2d', 'L1FilterPruner only supports 2d convolution layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
-        mask = torch.ones(weight.size()).type_as(weight)
-        try:
-            filters = weight.shape[0]
-            w_abs = weight.abs()
-            k = int(filters * config['sparsity'])
-            if k == 0:
-                return torch.ones(weight.shape).type_as(weight)
-            w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
-            threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
-            mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
+        filters = weight.shape[0]
+        w_abs = weight.abs()
+        w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
+        threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
+        mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)

         return mask
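With the boilerplate hoisted into `RankFilterPruner.calc_mask`, the L1 criterion reduces to a few tensor operations. A small worked example of the same ranking (a sketch mirroring `_get_mask`, not part of the commit):

```python
import torch

# Three 1x2x2 filters with L1 sums 0, 4 and 8
weight = torch.tensor([[[[0., 0.], [0., 0.]]],
                       [[[1., 1.], [1., 1.]]],
                       [[[2., 2.], [2., 2.]]]])
num_prune = 1

filters = weight.shape[0]
w_abs_structured = weight.abs().view(filters, -1).sum(dim=1)   # tensor([0., 4., 8.])
threshold = torch.topk(w_abs_structured, num_prune, largest=False)[0].max()   # 0.
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
print(mask[:, 0, 0, 0])   # tensor([0., 1., 1.]) -- only the weakest filter is zeroed out
```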
-class SlimPruner(Pruner):
+class L2FilterPruner(RankFilterPruner):
     """
-    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
-    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
-    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
-    https://arxiv.org/pdf/1708.06519.pdf
+    A structured pruning algorithm that prunes the filters with the
+    smallest L2 norm of the absolute kernel weights are masked.
     """

     def __init__(self, model, config_list):
         """
         Parameters
         ----------
+        model : torch.nn.module
+            Model to be pruned
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
-        weight_list = []
-        if len(config_list) > 1:
-            logger.warning('Slim pruner only supports 1 configuration')
-        config = config_list[0]
-        for (layer, config) in self.detect_modules_to_compress():
-            assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-            weight_list.append(layer.module.weight.data.abs().clone())
-        all_bn_weights = torch.cat(weight_list)
-        k = int(all_bn_weights.shape[0] * config['sparsity'])
-        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

-    def calc_mask(self, layer, config):
+    def _get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
-        Scale factors with the smallest absolute value in the BN layer are masked.
+        Filters with the smallest L2 norm of the absolute kernel weights are masked.
         Parameters
         ----------
-        layer : LayerInfo
-            the layer to instrument the compression operation
-        config : dict
-            layer's pruning config
+        base_mask : torch.Tensor
+            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        weight : torch.Tensor
+            Layer's weight
+        num_prune : int
+            Num of filters to prune
         Returns
         -------
         torch.Tensor
-            mask of the layer's weight
+            Mask of the layer's weight
         """
-        weight = layer.module.weight.data
-        op_name = layer.name
-        op_type = layer.type
-        assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
-        mask = torch.ones(weight.size()).type_as(weight)
-        try:
-            w_abs = weight.abs()
-            mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
+        filters = weight.shape[0]
+        w = weight.view(filters, -1)
+        w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
+        threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
+        mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)

         return mask
+
+
+class FPGMPruner(RankFilterPruner):
+    """
+    A filter pruner via geometric median.
+    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
+    https://arxiv.org/pdf/1811.00250.pdf
+    """
+
+    def __init__(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : pytorch model
+            the model user wants to compress
+        config_list: list
+            support key for each list item:
+                - sparsity: percentage of convolutional filters to be pruned.
+        """
+        super().__init__(model, config_list)
+
+    def _get_mask(self, base_mask, weight, num_prune):
+        """
+        Calculate the mask of given layer.
+        Filters with the smallest sum of its absolute kernel weights are masked.
+        Parameters
+        ----------
+        base_mask : torch.Tensor
+            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        weight : torch.Tensor
+            Layer's weight
+        num_prune : int
+            Num of filters to prune
+        Returns
+        -------
+        torch.Tensor
+            Mask of the layer's weight
+        """
+        min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
+        for idx in min_gm_idx:
+            base_mask[idx] = 0.
+        return base_mask
+
+    def _get_min_gm_kernel_idx(self, weight, n):
+        assert len(weight.size()) in [3, 4]
+        dist_list = []
+        for out_i in range(weight.size(0)):
+            dist_sum = self._get_distance_sum(weight, out_i)
+            dist_list.append((dist_sum, out_i))
+        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
+        return [x[1] for x in min_gm_kernels]
+
+    def _get_distance_sum(self, weight, out_idx):
+        """
+        Calculate the total distance between a specified filter (by out_idex and in_idx) and
+        all other filters.
+        Optimized verision of following naive implementation:
+        def _get_distance_sum(self, weight, in_idx, out_idx):
+            w = weight.view(-1, weight.size(-2), weight.size(-1))
+            dist_sum = 0.
+            for k in w:
+                dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
+            return dist_sum
+        Parameters
+        ----------
+        weight: Tensor
+            convolutional filter weight
+        out_idx: int
+            output channel index of specified filter, this method calculates the total distance
+            between this specified filter and all other filters.
+        Returns
+        -------
+        float32
+            The total distance
+        """
+        logger.debug('weight size: %s', weight.size())
+        assert len(weight.size()) in [3, 4], 'unsupported weight shape'
+
+        w = weight.view(weight.size(0), -1)
+        anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
+        x = w - anchor_w
+        x = (x * x).sum(-1)
+        x = torch.sqrt(x)
+
+        return x.sum()
+
+    def update_epoch(self, epoch):
+        self.mask_calculated_ops = set()
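The relocated `_get_distance_sum` keeps the docstring's promise that the vectorized form matches the naive `torch.dist` loop. A quick check of that equivalence on a random filter bank (a sketch, not part of the commit):

```python
import torch

weight = torch.randn(6, 4, 3, 3)   # 6 output filters
out_idx = 2

# Vectorized version, as in FPGMPruner._get_distance_sum
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
fast = torch.sqrt(((w - anchor_w) ** 2).sum(-1)).sum()

# Naive reference: sum of L2 distances from filter `out_idx` to every filter
slow = sum(torch.dist(w[i], w[out_idx], p=2) for i in range(w.size(0)))

assert torch.allclose(fast, slow)   # both give the same total distance
```

Flattening each filter to a row and broadcasting the anchor replaces an O(filters) Python loop per anchor with a single batched subtraction, which is why the docstring calls it the optimized version.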
src/sdk/pynni/nni/msg_dispatcher_base.py
View file @ 543239c6

@@ -163,19 +163,23 @@ class MsgDispatcherBase(Recoverable):
         raise NotImplementedError('handle_initialize not implemented')

     def handle_request_trial_jobs(self, data):
-        """The message dispatcher is demanded to generate `data` trial jobs.
-        These trial jobs should be sent via `send(CommandType.NewTrialJob, json_tricks.dumps(parameter))`,
-        where `parameter` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
-        Semantically, message dispatcher should do this `send` exactly `data` times.
+        """The message dispatcher is demanded to generate ``data`` trial jobs.
+        These trial jobs should be sent via ``send(CommandType.NewTrialJob, json_tricks.dumps(parameter))``,
+        where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
+        Semantically, message dispatcher should do this ``send`` exactly ``data`` times.

         The JSON sent by this method should follow the format of
-        {
-            "parameter_id": 42
-            "parameters": {
-                // this will be received by trial
-            },
-            "parameter_source": "algorithm" // optional
-        }
+
+        ::
+
+            {
+                "parameter_id": 42
+                "parameters": {
+                    // this will be received by trial
+                },
+                "parameter_source": "algorithm" // optional
+            }

         Parameters
         ----------
         data: int

@@ -211,6 +215,7 @@ class MsgDispatcherBase(Recoverable):
     def handle_report_metric_data(self, data):
         """Called when metric data is reported or new parameters are requested (for multiphase).
         When new parameters are requested, this method should send a new parameter.
+
         Parameters
         ----------
         data: dict

@@ -219,6 +224,7 @@ class MsgDispatcherBase(Recoverable):
         `REQUEST_PARAMETER` is used to request new parameters for multiphase trial job. In this case,
         the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py`
         as an example.
+
         Raises
         ------
         ValueError

@@ -228,6 +234,7 @@ class MsgDispatcherBase(Recoverable):
     def handle_trial_end(self, data):
         """Called when the state of one of the trials is changed
+
         Parameters
         ----------
         data: dict

@@ -235,5 +242,6 @@ class MsgDispatcherBase(Recoverable):
             trial_job_id: the id generated by training service.
            event: the job’s state.
            hyper_params: the string that is sent by message dispatcher during the creation of trials.
+
         """
         raise NotImplementedError('handle_trial_end not implemented')
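A sketch of the payload the `handle_request_trial_jobs` docstring describes, sent once per requested trial (`send` and `CommandType` come from the NNI SDK as the docstring states; the parameter values here are illustrative):

```python
import json_tricks

# Built by the tuner for one trial; repeated `data` times per request
new_trial = {
    'parameter_id': 42,
    'parameters': {                    # this dict is what the trial code receives
        'learning_rate': 0.001,
        'batch_size': 32,
    },
    'parameter_source': 'algorithm',   # optional
}
payload = json_tricks.dumps(new_trial)
# send(CommandType.NewTrialJob, payload)   # as described in handle_request_trial_jobs
```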
src/sdk/pynni/tests/test_compressor.py
View file @ 543239c6

@@ -58,8 +58,9 @@ def tf2(func):
     return test_tf2_func

 # for fpgm filter pruner test
 w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])
+

 class CompressorTestCase(TestCase):

@@ -69,19 +70,19 @@ class CompressorTestCase(TestCase):
         config_list = [{
             'quant_types': ['weight'],
             'quant_bits': 8,
-            'op_types':['Conv2d', 'Linear']
+            'op_types': ['Conv2d', 'Linear']
         }, {
             'quant_types': ['output'],
             'quant_bits': 8,
             'quant_start_step': 0,
-            'op_types':['ReLU']
+            'op_types': ['ReLU']
         }]
         model.relu = torch.nn.ReLU()
         quantizer = torch_compressor.QAT_Quantizer(model, config_list)
         quantizer.compress()
         modules_to_compress = quantizer.get_modules_to_compress()
         modules_to_compress_name = [t[0].name for t in modules_to_compress]
         assert "conv1" in modules_to_compress_name
         assert "conv2" in modules_to_compress_name
         assert "fc1" in modules_to_compress_name

@@ -179,7 +180,8 @@ class CompressorTestCase(TestCase):
         w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
                       np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
         model = TorchModel()
-        config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
+        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
+                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
         pruner = torch_compressor.L1FilterPruner(model, config_list)
         model.conv1.weight.data = torch.tensor(w).float()

@@ -236,12 +238,12 @@ class CompressorTestCase(TestCase):
         config_list = [{
             'quant_types': ['weight'],
             'quant_bits': 8,
-            'op_types':['Conv2d', 'Linear']
+            'op_types': ['Conv2d', 'Linear']
         }, {
             'quant_types': ['output'],
             'quant_bits': 8,
             'quant_start_step': 0,
-            'op_types':['ReLU']
+            'op_types': ['ReLU']
         }]
         model.relu = torch.nn.ReLU()
         quantizer = torch_compressor.QAT_Quantizer(model, config_list)

@@ -253,7 +255,7 @@ class CompressorTestCase(TestCase):
         quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
         assert math.isclose(model.conv2.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.zero_point == 0
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
         quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
         assert math.isclose(model.conv2.scale, 6 / 255, abs_tol=eps)

@@ -271,5 +273,6 @@ class CompressorTestCase(TestCase):
         assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
         assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
+

 if __name__ == '__main__':
     main()
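With the updated config the L1 test is easy to verify by hand: `conv1` holds five filters whose kernels are constant 0, 1, 2, 3 and 4, so sparsity 0.2 prunes exactly the all-zero filter. A sketch of the arithmetic (mirrors the test above, not part of it):

```python
import numpy as np

w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
              np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
l1_per_filter = np.abs(w).reshape(w.shape[0], -1).sum(axis=1)
print(l1_per_filter)            # [  0.  27.  54.  81. 108.]
num_prune = int(w.shape[0] * 0.2)
print(num_prune)                # 1 -> only the all-zero filter is masked
```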