OpenDAS / nni · Commits

Commit 543239c6 (Unverified)
Authored Dec 12, 2019 by SparkSnail; committed by GitHub on Dec 12, 2019.

    Merge pull request #220 from microsoft/master

    merge master

Parents: 32efaa36, 659480f2
Showing 14 changed files with 371 additions and 409 deletions (+371 −409).
  src/nni_manager/training_service/pai/paiJobInfoCollector.ts                       +7    -12
  src/nni_manager/training_service/pai/paiJobRestServer.ts                          +2    -3
  src/nni_manager/training_service/pai/paiTrainingService.ts                        +19   -33
  src/nni_manager/training_service/pai/paiTrialConfig.ts                            +1    -1
  src/nni_manager/training_service/remote_machine/gpuScheduler.ts                   +11   -14
  src/nni_manager/training_service/remote_machine/remoteMachineData.ts              +10   -11
  src/nni_manager/training_service/remote_machine/remoteMachineJobRestServer.ts     +2    -3
  src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts   +24   -38
  src/nni_manager/training_service/remote_machine/sshClientUtility.ts               +37   -41
  src/nni_manager/yarn.lock                                                         +8    -84
  src/sdk/pynni/nni/batch_tuner/batch_tuner.py                                      +9    -7
  src/sdk/pynni/nni/compression/torch/builtin_pruners.py                            +211  -143
  src/sdk/pynni/nni/msg_dispatcher_base.py                                          +19   -11
  src/sdk/pynni/tests/test_compressor.py                                            +11   -8
src/nni_manager/training_service/pai/paiJobInfoCollector.ts  (view file @ 543239c6)

@@ -3,7 +3,6 @@
 'use strict';
 
-// tslint:disable-next-line:no-implicit-dependencies
 import * as request from 'request';
 import { Deferred } from 'ts-deferred';
 import { NNIError, NNIErrorNames } from '../../common/errors';
@@ -16,10 +15,10 @@ import { PAITrialJobDetail } from './paiData';
  * Collector PAI jobs info from PAI cluster, and update pai job status locally
  */
 export class PAIJobInfoCollector {
-    private readonly trialJobsMap : Map<string, PAITrialJobDetail>;
+    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
     private readonly log: Logger = getLogger();
-    private readonly statusesNeedToCheck : TrialJobStatus[];
-    private readonly finalStatuses : TrialJobStatus[];
+    private readonly statusesNeedToCheck: TrialJobStatus[];
+    private readonly finalStatuses: TrialJobStatus[];
 
     constructor(jobMap: Map<string, PAITrialJobDetail>) {
         this.trialJobsMap = jobMap;
@@ -27,12 +26,12 @@ export class PAIJobInfoCollector {
         this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
     }
 
-    public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig) : Promise<void> {
+    public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig): Promise<void> {
         if (paiClusterConfig === undefined || paiToken === undefined) {
             return Promise.resolve();
         }
-        const updatePaiTrialJobs : Promise<void>[] = [];
+        const updatePaiTrialJobs: Promise<void>[] = [];
         for (const [trialJobId, paiTrialJob] of this.trialJobsMap) {
             if (paiTrialJob === undefined) {
                 throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
@@ -43,9 +42,8 @@ export class PAIJobInfoCollector {
         await Promise.all(updatePaiTrialJobs);
     }
 
-    private getSinglePAITrialJobInfo(paiTrialJob : PAITrialJobDetail, paiToken : string, paiClusterConfig : PAIClusterConfig) : Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+    private getSinglePAITrialJobInfo(paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig): Promise<void> {
+        const deferred: Deferred<void> = new Deferred<void>();
         if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
             deferred.resolve();
@@ -55,7 +53,6 @@ export class PAIJobInfoCollector {
         // Rest call to get PAI job info and update status
         // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
         const getJobInfoRequest: request.Options = {
-            // tslint:disable-next-line:no-http-string
             uri: `http://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
             method: 'GET',
             json: true,
@@ -65,7 +62,6 @@ export class PAIJobInfoCollector {
             }
         };
 
-        // tslint:disable: no-unsafe-any no-any cyclomatic-complexity
         //TODO : pass in request timeout param?
         request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 500) {
@@ -129,5 +125,4 @@ export class PAIJobInfoCollector {
         return deferred.promise;
     }
-    // tslint:enable: no-unsafe-any no-any
 }
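The collector above leans on the ts-deferred idiom that recurs throughout this commit: create a Deferred, resolve or reject it inside the callback of the `request` call, and hand back `deferred.promise` so callers can await it individually or gather a batch with Promise.all. A minimal standalone sketch of that pattern — the `pollOnce` name and the URL parameter are illustrative, not part of NNI:

import * as request from 'request';
import { Deferred } from 'ts-deferred';

// Hypothetical helper showing the Deferred-wrapping idiom used by the
// training services: settle the Deferred inside the callback, return the
// promise so the caller can await it or collect many via Promise.all.
function pollOnce(uri: string): Promise<number> {
    const deferred: Deferred<number> = new Deferred<number>();
    request({ uri, method: 'GET', json: true }, (error: Error, response: request.Response) => {
        if (error !== undefined && error !== null) {
            deferred.reject(error);
        } else {
            deferred.resolve(response.statusCode);
        }
    });

    return deferred.promise;
}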
src/nni_manager/training_service/pai/paiJobRestServer.ts  (view file @ 543239c6)

@@ -24,7 +24,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
     private parameterFileMetaList: ParameterFileMeta[] = [];
 
     @Inject
-    private readonly paiTrainingService : PAITrainingService;
+    private readonly paiTrainingService: PAITrainingService;
 
     /**
      * constructor to provide NNIRestServer's own rest property, e.g. port
@@ -34,8 +34,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
         this.paiTrainingService = component.get(PAITrainingService);
     }
 
-    // tslint:disable-next-line:no-any
-    protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
+    protected handleTrialMetrics(jobId: string, metrics: any[]): void {
         // Split metrics array into single metric, then emit
         // Warning: If not split metrics into single ones, the behavior will be UNKNOWN
         for (const singleMetric of metrics) {
src/nni_manager/training_service/pai/paiTrainingService.ts  (view file @ 543239c6)

@@ -3,17 +3,14 @@
 'use strict';
 
 import * as cpp from 'child-process-promise';
 import * as fs from 'fs';
 import * as path from 'path';
-// tslint:disable-next-line:no-implicit-dependencies
 import * as request from 'request';
 import * as component from '../../common/component';
 import { EventEmitter } from 'events';
 import { Deferred } from 'ts-deferred';
 import { String } from 'typescript-string-operations';
 import { MethodNotImplementedError } from '../../common/errors';
 import { getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
 import {
@@ -47,13 +44,12 @@ class PAITrainingService implements TrainingService {
     private paiClusterConfig?: PAIClusterConfig;
     private readonly jobQueue: string[];
     private stopping: boolean = false;
-    // tslint:disable-next-line:no-any
     private hdfsClient: any;
     private paiToken? : string;
     private paiTokenUpdateTime?: number;
     private readonly paiTokenUpdateInterval: number;
-    private readonly experimentId! : string;
-    private readonly paiJobCollector : PAIJobInfoCollector;
+    private readonly experimentId!: string;
+    private readonly paiJobCollector: PAIJobInfoCollector;
     private paiRestServerPort?: number;
     private nniManagerIpConfig?: NNIManagerIpConfig;
     private copyExpCodeDirPromise?: Promise<void>;
@@ -126,7 +122,7 @@ class PAITrainingService implements TrainingService {
         if (this.paiClusterConfig === undefined) {
             throw new Error(`paiClusterConfig not initialized!`);
         }
-        const deferred : Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
+        const deferred: Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
 
         this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
@@ -137,7 +133,7 @@ class PAITrainingService implements TrainingService {
         const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
         const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
-        const hdfsLogPath : string = String.Format(
+        const hdfsLogPath: string = String.Format(
             PAI_LOG_PATH_FORMAT,
             this.paiClusterConfig.host,
             hdfsOutputDir
@@ -173,10 +169,9 @@ class PAITrainingService implements TrainingService {
         return true;
     }
 
-    // tslint:disable:no-http-string
     public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
-        const trialJobDetail : PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
-        const deferred : Deferred<void> = new Deferred<void>();
+        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
+        const deferred: Deferred<void> = new Deferred<void>();
         if (trialJobDetail === undefined) {
             this.log.error(`cancelTrialJob: trial job id ${trialJobId} not found`);
@@ -205,7 +200,6 @@ class PAITrainingService implements TrainingService {
         // Set trialjobDetail's early stopped field, to mark the job's cancellation source
         trialJobDetail.isEarlyStopped = isEarlyStopped;
 
-        // tslint:disable-next-line:no-any
         request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                 this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
@@ -219,10 +213,8 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
 
-    // tslint:disable: no-unsafe-any no-any
-    // tslint:disable-next-line:max-func-body-length
     public async setClusterMetadata(key: string, value: string): Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
 
         switch (key) {
             case TrialConfigMetadataKey.NNI_MANAGER_IP:
@@ -300,10 +292,9 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
 
-    // tslint:enable: no-unsafe-any
     public getClusterMetadata(key: string): Promise<string> {
-        const deferred : Deferred<string> = new Deferred<string>();
+        const deferred: Deferred<string> = new Deferred<string>();
         deferred.resolve();
@@ -314,14 +305,13 @@ class PAITrainingService implements TrainingService {
         this.log.info('Stopping PAI training service...');
         this.stopping = true;
 
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
         const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
 
         try {
             await restServer.stop();
             deferred.resolve();
             this.log.info('PAI Training service rest server stopped successfully.');
         } catch (error) {
-            // tslint:disable-next-line: no-unsafe-any
             this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
             deferred.reject(error);
         }
@@ -329,13 +319,12 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
 
-    public get MetricsEmitter() : EventEmitter {
+    public get MetricsEmitter(): EventEmitter {
         return this.metricsEmitter;
     }
 
-    // tslint:disable-next-line:max-func-body-length
     private async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
-        const deferred : Deferred<boolean> = new Deferred<boolean>();
+        const deferred: Deferred<boolean> = new Deferred<boolean>();
         const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
 
         if (trialJobDetail === undefined) {
@@ -372,7 +361,7 @@ class PAITrainingService implements TrainingService {
         //create tmp trial working folder locally.
         await execMkdir(trialLocalTempFolder);
 
-        const runScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
+        const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
         // Write NNI installation file to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
@@ -385,10 +374,9 @@ class PAITrainingService implements TrainingService {
         }
 
         const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
         const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
-        // tslint:disable-next-line: strict-boolean-expressions
         const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
         const version: string = this.versionCheck ? await getVersion() : '';
-        const nniPaiTrialCommand : string = String.Format(
+        const nniPaiTrialCommand: string = String.Format(
             PAI_TRIAL_COMMAND_FORMAT,
             // PAI will copy job's codeDir into /root directory
             `$PWD/${trialJobId}`,
@@ -409,9 +397,8 @@ class PAITrainingService implements TrainingService {
         )
         .replace(/\r\n|\n|\r/gm, '');
 
-        // tslint:disable-next-line:no-console
         this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
 
-        const paiTaskRoles : PAITaskRole[] = [
+        const paiTaskRoles: PAITaskRole[] = [
             new PAITaskRole(
                 `nni_trail_${trialJobId}`,
                 // Task role number
@@ -431,7 +418,7 @@ class PAITrainingService implements TrainingService {
             )
         ];
 
-        const paiJobConfig : PAIJobConfig = new PAIJobConfig(
+        const paiJobConfig: PAIJobConfig = new PAIJobConfig(
             // Job name
             trialJobDetail.paiJobName,
             // Docker image
@@ -451,7 +438,7 @@ class PAITrainingService implements TrainingService {
             await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
         } catch (error) {
             this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
-            trialJobDetail.status = 'FAILED';
+            trialJobDetail.status = 'FAILED'; // eslint-disable-line require-atomic-updates
             deferred.resolve(true);
 
             return deferred.promise;
@@ -469,10 +456,9 @@ class PAITrainingService implements TrainingService {
                 Authorization: `Bearer ${this.paiToken}`
             }
         };
 
-        // tslint:disable:no-any no-unsafe-any
         request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 400) {
-                const errorMessage : string = (error !== undefined && error !== null) ? error.message :
+                const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                     `Submit trial ${trialJobId} failed, http code: ${response.statusCode}, http body: ${response.body.message}`;
                 trialJobDetail.status = 'FAILED';
                 deferred.resolve(true);
@@ -527,7 +513,7 @@ class PAITrainingService implements TrainingService {
      * Update pai token by the interval time or initialize the pai token
      */
     private async updatePaiToken(): Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
 
         const currentTime: number = new Date().getTime();
         //If pai token initialized and not reach the interval time, do not update
@@ -603,7 +589,7 @@ class PAITrainingService implements TrainingService {
     }
 
     private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
         const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
         const req: request.Options = {
             uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
src/nni_manager/training_service/pai/paiTrialConfig.ts  (view file @ 543239c6)

@@ -15,7 +15,7 @@ export class PAITrialConfig extends TrialConfig {
     public readonly dataDir: string;
     public readonly outputDir: string;
 
-    constructor(command : string, codeDir : string, gpuNum : number, cpuNum : number, memoryMB : number,
+    constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
                 image: string, dataDir: string, outputDir: string) {
         super(command, codeDir, gpuNum);
         this.cpuNum = cpuNum;
src/nni_manager/training_service/remote_machine/gpuScheduler.ts  (view file @ 543239c6)

@@ -5,7 +5,6 @@
 import * as assert from 'assert';
 import { getLogger, Logger } from '../../common/log';
-import { TrialJobDetail } from '../../common/trainingService';
 import { randomSelect } from '../../common/utils';
 import { GPUInfo } from '../common/gpuData';
 import {
@@ -19,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
  */
 export class GPUScheduler {
-    private readonly machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>;
+    private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>;
     private readonly log: Logger = getLogger();
     private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
     private roundRobinIndex: number = 0;
@@ -29,7 +28,7 @@ export class GPUScheduler {
      * Constructor
      * @param machineSSHClientMap map from remote machine to sshClient
      */
-    constructor(machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>) {
+    constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) {
         assert(machineSSHClientMap.size > 0);
         this.machineSSHClientMap = machineSSHClientMap;
         this.configuredRMs = Array.from(machineSSHClientMap.keys());
@@ -39,7 +38,7 @@ export class GPUScheduler {
      * Schedule a machine according to the constraints (requiredGPUNum)
      * @param requiredGPUNum required GPU number
     */
     public scheduleMachine(requiredGPUNum: number | undefined, trialJobDetail : RemoteMachineTrialJobDetail) : RemoteMachineScheduleResult {
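The cuda_visible_device → cudaVisibleDevice rename in the last hunk ends with scheduleInfo carrying the allocated GPU indices flattened into a comma-separated string. A small sketch of that flattening, assuming a simplified GPUInfo shape (the real interface lives in ../common/gpuData):

// Simplified stand-in for the GPUInfo type from '../common/gpuData'.
interface GPUInfo {
    index: number;
}

// Turn the scheduler's allocated GPUs into the comma-separated index list
// carried by scheduleInfo.cudaVisibleDevice, e.g. [{index: 0}, {index: 2}] -> '0,2'.
function toCudaVisibleDevice(allocatedGPUs: GPUInfo[]): string {
    return allocatedGPUs
        .map((gpuInfo: GPUInfo) => gpuInfo.index)
        .join(',');
}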
src/nni_manager/training_service/remote_machine/remoteMachineData.ts  (view file @ 543239c6)

@@ -13,13 +13,13 @@ import { GPUInfo, GPUSummary } from '../common/gpuData';
  * Metadata of remote machine for configuration and statuc query
  */
 export class RemoteMachineMeta {
-    public readonly ip : string = '';
-    public readonly port : number = 22;
-    public readonly username : string = '';
+    public readonly ip: string = '';
+    public readonly port: number = 22;
+    public readonly username: string = '';
     public readonly passwd: string = '';
     public readonly sshKeyPath?: string;
     public readonly passphrase?: string;
-    public gpuSummary : GPUSummary | undefined;
+    public gpuSummary: GPUSummary | undefined;
     public readonly gpuIndices?: string;
     public readonly maxTrialNumPerGpu?: number;
     //TODO: initialize varialbe in constructor
@@ -43,11 +43,11 @@ export function parseGpuIndices(gpuIndices?: string): Set<number> | undefined {
  * The execution result for command executed on remote machine
  */
 export class RemoteCommandResult {
-    public readonly stdout : string;
-    public readonly stderr : string;
-    public readonly exitCode : number;
+    public readonly stdout: string;
+    public readonly stderr: string;
+    public readonly exitCode: number;
 
-    constructor(stdout : string, stderr : string, exitCode : number) {
+    constructor(stdout: string, stderr: string, exitCode: number) {
         this.stdout = stdout;
         this.stderr = stderr;
         this.exitCode = exitCode;
@@ -186,7 +186,6 @@ export class SSHClientManager {
     /**
      * Create a new ssh connection client and initialize it
     */
-    // tslint:disable:non-literal-fs-path
     private initNewSSHClient(): Promise<Client> {
         const deferred: Deferred<Client> = new Deferred<Client>();
         const conn: Client = new Client();
@@ -225,9 +224,9 @@ export class SSHClientManager {
     }
 }
 
-export type RemoteMachineScheduleResult = { scheduleInfo : RemoteMachineScheduleInfo | undefined; resultType : ScheduleResultType };
+export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType };
 
-export type RemoteMachineScheduleInfo = { rmMeta : RemoteMachineMeta; cuda_visible_device : string };
+export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string };
 
 export enum ScheduleResultType {
     // Schedule succeeded
src/nni_manager/training_service/remote_machine/remoteMachineJobRestServer.ts  (view file @ 543239c6)

@@ -15,7 +15,7 @@ import { RemoteMachineTrainingService } from './remoteMachineTrainingService';
 @component.Singleton
 export class RemoteMachineJobRestServer extends ClusterJobRestServer {
     @Inject
-    private readonly remoteMachineTrainingService : RemoteMachineTrainingService;
+    private readonly remoteMachineTrainingService: RemoteMachineTrainingService;
 
     /**
      * constructor to provide NNIRestServer's own rest property, e.g. port
@@ -25,8 +25,7 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer {
         this.remoteMachineTrainingService = component.get(RemoteMachineTrainingService);
     }
 
-    // tslint:disable-next-line:no-any
-    protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
+    protected handleTrialMetrics(jobId: string, metrics: any[]): void {
         // Split metrics array into single metric, then emit
         // Warning: If not split metrics into single ones, the behavior will be UNKNOWNls
         for (const singleMetric of metrics) {
src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts  (view file @ 543239c6)

@@ -4,12 +4,10 @@
 'use strict';
 
 import * as assert from 'assert';
-import * as cpp from 'child-process-promise';
 import { EventEmitter } from 'events';
 import * as fs from 'fs';
-import * as os from 'os';
 import * as path from 'path';
-import { Client, ConnectConfig } from 'ssh2';
+import { Client } from 'ssh2';
 import { Deferred } from 'ts-deferred';
 import { String } from 'typescript-string-operations';
 import * as component from '../../common/component';
@@ -29,12 +27,12 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
 import { GPUSummary } from '../common/gpuData';
 import { TrialConfig } from '../common/trialConfig';
 import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
-import { execCopydir, execMkdir, execRemove, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
+import { execCopydir, execMkdir, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
 import { GPUScheduler } from './gpuScheduler';
 import {
-    HOST_JOB_SHELL_FORMAT, RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
+    RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
     RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail,
-    ScheduleResultType, SSHClient, SSHClientManager
+    ScheduleResultType, SSHClientManager
 } from './remoteMachineData';
 import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
 import { SSHClientUtility } from './sshClientUtility';
@@ -93,7 +91,7 @@ class RemoteMachineTrainingService implements TrainingService {
         while (this.jobQueue.length > 0) {
             this.updateGpuReservation();
             const trialJobId: string = this.jobQueue[0];
-            const prepareResult : boolean = await this.prepareTrialJob(trialJobId);
+            const prepareResult: boolean = await this.prepareTrialJob(trialJobId);
             if (prepareResult) {
                 // Remove trial job with trialJobId from job queue
                 this.jobQueue.shift();
@@ -208,7 +206,6 @@ class RemoteMachineTrainingService implements TrainingService {
      * Submit trial job
      * @param form trial job description form
     */
-    // tslint:disable-next-line:informative-docs
     public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
         if (this.trialConfig === undefined) {
             throw new Error('trial config is not initialized');
@@ -241,12 +238,7 @@ class RemoteMachineTrainingService implements TrainingService {
         if (trialJobDetail === undefined) {
             throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
         }
-        const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
-        if (rmMeta !== undefined) {
-            await this.writeParameterFile(trialJobId, form.hyperParameters, rmMeta);
-        } else {
-            throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
-        }
+        await this.writeParameterFile(trialJobId, form.hyperParameters);
 
         return trialJobDetail;
     }
@@ -262,7 +254,6 @@ class RemoteMachineTrainingService implements TrainingService {
      * Cancel trial job
      * @param trialJobId ID of trial job
     */
-    // tslint:disable:informative-docs no-unsafe-any
     public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
         const deferred: Deferred<void> = new Deferred<void>();
         const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
@@ -272,7 +263,7 @@ class RemoteMachineTrainingService implements TrainingService {
         }
 
         // Remove the job with trialJobId from job queue
-        const index : number = this.jobQueue.indexOf(trialJobId);
+        const index: number = this.jobQueue.indexOf(trialJobId);
         if (index >= 0) {
             this.jobQueue.splice(index, 1);
         }
@@ -319,14 +310,13 @@ class RemoteMachineTrainingService implements TrainingService {
                 await this.setupConnections(value);
                 this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap);
                 break;
-            case TrialConfigMetadataKey.TRIAL_CONFIG:
+            case TrialConfigMetadataKey.TRIAL_CONFIG: {
                 const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
                 // Parse trial config failed, throw Error
                 if (remoteMachineTrailConfig === undefined) {
                     throw new Error('trial config parsed failed');
                 }
                 // codeDir is not a valid directory, throw Error
-                // tslint:disable-next-line:non-literal-fs-path
                 if (!fs.lstatSync(remoteMachineTrailConfig.codeDir).isDirectory()) {
                     throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`);
@@ -343,6 +333,7 @@ class RemoteMachineTrainingService implements TrainingService {
                 this.trialConfig = remoteMachineTrailConfig;
                 break;
+            }
             case TrialConfigMetadataKey.MULTI_PHASE:
                 this.isMultiPhase = (value === 'true' || value === 'True');
                 break;
@@ -444,7 +435,6 @@ class RemoteMachineTrainingService implements TrainingService {
         await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn);
         //Begin to execute gpu_metrics_collection scripts
-        // tslint:disable-next-line: no-floating-promises
         const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir);
         SSHClientUtility.remoteExeCommand(`bash -c '${script}'`, conn);
@@ -464,7 +454,7 @@ class RemoteMachineTrainingService implements TrainingService {
     }
 
     private async prepareTrialJob(trialJobId: string): Promise<boolean> {
-        const deferred : Deferred<boolean> = new Deferred<boolean>();
+        const deferred: Deferred<boolean> = new Deferred<boolean>();
 
         if (this.trialConfig === undefined) {
             throw new Error('trial config is not initialized');
@@ -485,13 +475,13 @@ class RemoteMachineTrainingService implements TrainingService {
         // get an ssh client from scheduler
         const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobDetail);
         if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
-            const errorMessage : string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
+            const errorMessage: string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
             this.log.error(errorMessage);
             deferred.reject();
             throw new NNIError(NNIErrorNames.RESOURCE_NOT_AVAILABLE, errorMessage);
         } else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED
             && rmScheduleResult.scheduleInfo !== undefined) {
-            const rmScheduleInfo : RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
+            const rmScheduleInfo: RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
             const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
             trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
@@ -521,7 +511,7 @@ class RemoteMachineTrainingService implements TrainingService {
         if (this.trialConfig === undefined) {
             throw new Error('trial config is not initialized');
         }
-        const cuda_visible_device: string = rmScheduleInfo.cuda_visible_device;
+        const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice;
         const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
         if (sshClient === undefined) {
             assert(false, 'sshClient is undefined.');
@@ -543,19 +533,18 @@ class RemoteMachineTrainingService implements TrainingService {
         // See definition in remoteMachineData.ts
         let command: string;
-        // Set CUDA_VISIBLE_DEVICES environment variable based on cuda_visible_device
-        // If no valid cuda_visible_device is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
+        // Set CUDA_VISIBLE_DEVICES environment variable based on cudaVisibleDevice
+        // If no valid cudaVisibleDevice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
         // If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script
         if (this.trialConfig.gpuNum === undefined) {
             command = this.trialConfig.command;
         } else {
-            if (typeof cuda_visible_device === 'string' && cuda_visible_device.length > 0) {
-                command = `CUDA_VISIBLE_DEVICES=${cuda_visible_device} ${this.trialConfig.command}`;
+            if (typeof cudaVisibleDevice === 'string' && cudaVisibleDevice.length > 0) {
+                command = `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice} ${this.trialConfig.command}`;
            } else {
                 command = `CUDA_VISIBLE_DEVICES=" " ${this.trialConfig.command}`;
             }
         }
-        // tslint:disable-next-line: strict-boolean-expressions
         const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
         if (this.remoteRestServerPort === undefined) {
             const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
@@ -584,16 +573,15 @@ class RemoteMachineTrainingService implements TrainingService {
         //create tmp trial working folder locally.
         await execCopydir(this.trialConfig.codeDir, trialLocalTempFolder);
-        const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
+        const installScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
         // Write NNI installation file to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' });
         // Write file content ( run.sh and parameter.cfg ) to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptTrialContent, { encoding: 'utf8' });
-        await this.writeParameterFile(trialJobId, form.hyperParameters, rmScheduleInfo.rmMeta);
+        await this.writeParameterFile(trialJobId, form.hyperParameters);
         // Copy files in codeDir to remote working directory
         await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS);
         // Execute command in remote machine
-        // tslint:disable-next-line: no-floating-promises
         SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
     }
@@ -610,6 +598,7 @@ class RemoteMachineTrainingService implements TrainingService {
         const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
         const jobpidPath: string = this.getJobPidPath(trialJob.id);
         const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code');
+        /* eslint-disable require-atomic-updates */
         try {
             const killResult: number = (await SSHClientUtility.remoteExeCommand(`kill -0 \`cat ${jobpidPath}\``, sshClient)).exitCode;
             // if the process of jobpid is not alive any more
@@ -646,7 +635,7 @@ class RemoteMachineTrainingService implements TrainingService {
                 deferred.resolve(trialJob);
             }
         }
-
+        /* eslint-enable require-atomic-updates */
         return deferred.promise;
     }
@@ -662,7 +651,7 @@ class RemoteMachineTrainingService implements TrainingService {
         return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId());
     }
 
-    public get MetricsEmitter() : EventEmitter {
+    public get MetricsEmitter(): EventEmitter {
         return this.metricsEmitter;
     }
@@ -672,13 +661,10 @@ class RemoteMachineTrainingService implements TrainingService {
             throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${jobId}`);
         }
-        let jobpidPath: string;
-        jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
-
-        return jobpidPath;
+        return unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
     }
 
-    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
+    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
         const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
         if (sshClient === undefined) {
             throw new Error('sshClient is undefined.');
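The CUDA_VISIBLE_DEVICES branching in the @@ -543,19 +533,18 @@ hunk reads as three cases: no gpuNum configured leaves the trial command untouched; a non-empty cudaVisibleDevice pins the trial to those GPU indices; an empty one hides every GPU. A standalone sketch of the same logic (the function name is illustrative, not part of the service):

// Mirrors the CUDA_VISIBLE_DEVICES branching from prepareTrialJob:
//   gpuNum undefined          -> run the command as-is
//   cudaVisibleDevice = '0,2' -> CUDA_VISIBLE_DEVICES=0,2 <command>
//   cudaVisibleDevice = ''    -> CUDA_VISIBLE_DEVICES=" " <command> (hide all GPUs)
function buildTrialCommand(baseCommand: string, gpuNum: number | undefined, cudaVisibleDevice: string): string {
    if (gpuNum === undefined) {
        return baseCommand;
    }
    if (typeof cudaVisibleDevice === 'string' && cudaVisibleDevice.length > 0) {
        return `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice} ${baseCommand}`;
    }

    return `CUDA_VISIBLE_DEVICES=" " ${baseCommand}`;
}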
src/nni_manager/training_service/remote_machine/sshClientUtility.ts  (view file @ 543239c6)

@@ -4,7 +4,6 @@
 'use strict';
 
 import * as assert from 'assert';
-import * as cpp from 'child-process-promise';
 import * as os from 'os';
 import * as path from 'path';
 import { Client, ClientChannel, SFTPWrapper } from 'ssh2';
@@ -22,44 +21,18 @@ import { RemoteCommandResult } from './remoteMachineData';
  *
 */
 export namespace SSHClientUtility {
-    /**
-     * Copy files and directories in local directory recursively to remote directory
-     * @param localDirectory local diretory
-     * @param remoteDirectory remote directory
-     * @param sshClient SSH client
-     */
-    export async function copyDirectoryToRemote(localDirectory : string, remoteDirectory : string, sshClient : Client, remoteOS : string) : Promise<void> {
-        const deferred: Deferred<void> = new Deferred<void>();
-        const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
-        const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
-        const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
-
-        // Compress files in local directory to experiment root directory
-        await tarAdd(localTarPath, localDirectory);
-        // Copy the compressed file to remoteDirectory and delete it
-        await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
-        await execRemove(localTarPath);
-        // Decompress the remote compressed file in and delete it
-        await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
-        await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
-        deferred.resolve();
-
-        return deferred.promise;
-    }
-
     /**
      * Copy local file to remote path
     * @param localFilePath the path of local file
     * @param remoteFilePath the target path in remote machine
    * @param sshClient SSH Client
    */
-    export function copyFileToRemote(localFilePath : string, remoteFilePath : string, sshClient : Client) : Promise<boolean> {
+    export function copyFileToRemote(localFilePath: string, remoteFilePath: string, sshClient: Client): Promise<boolean> {
         const log: Logger = getLogger();
         log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
         assert(sshClient !== undefined);
         const deferred: Deferred<boolean> = new Deferred<boolean>();
-        sshClient.sftp((err : Error, sftp : SFTPWrapper) => {
+        sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
             if (err !== undefined && err !== null) {
                 log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
                 deferred.reject(err);
@@ -67,7 +40,7 @@ export namespace SSHClientUtility {
                 return;
             }
             assert(sftp !== undefined);
-            sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr : Error) => {
+            sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => {
                 sftp.end();
                 if (fastPutErr !== undefined && fastPutErr !== null) {
                     deferred.reject(fastPutErr);
@@ -85,16 +58,15 @@ export namespace SSHClientUtility {
     * @param command the command to execute remotely
     * @param client SSH Client
    */
-    // tslint:disable:no-unsafe-any no-any
-    export function remoteExeCommand(command : string, client : Client): Promise<RemoteCommandResult> {
+    export function remoteExeCommand(command: string, client: Client): Promise<RemoteCommandResult> {
         const log: Logger = getLogger();
         log.debug(`remoteExeCommand: command: [${command}]`);
-        const deferred : Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
+        const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
         let stdout: string = '';
         let stderr: string = '';
-        let exitCode : number;
+        let exitCode: number;
 
-        client.exec(command, (err : Error, channel : ClientChannel) => {
+        client.exec(command, (err: Error, channel: ClientChannel) => {
             if (err !== undefined && err !== null) {
                 log.error(`remoteExeCommand: ${err.message}`);
                 deferred.reject(err);
@@ -102,14 +74,14 @@ export namespace SSHClientUtility {
                 return;
             }
 
-            channel.on('data', (data : any, dataStderr : any) => {
+            channel.on('data', (data: any, dataStderr: any) => {
                 if (dataStderr !== undefined && dataStderr !== null) {
                     stderr += data.toString();
                 } else {
                     stdout += data.toString();
                 }
-            }).on('exit', (code : any, signal : any) => {
+            }).on('exit', (code: any, signal: any) => {
                 exitCode = <number>code;
                 deferred.resolve({
                     stdout: stdout,
@@ -122,9 +94,34 @@ export namespace SSHClientUtility {
         return deferred.promise;
     }
 
+    /**
+     * Copy files and directories in local directory recursively to remote directory
+     * @param localDirectory local diretory
+     * @param remoteDirectory remote directory
+     * @param sshClient SSH client
+     */
+    export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string): Promise<void> {
+        const deferred: Deferred<void> = new Deferred<void>();
+        const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
+        const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
+        const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
+
+        // Compress files in local directory to experiment root directory
+        await tarAdd(localTarPath, localDirectory);
+        // Copy the compressed file to remoteDirectory and delete it
+        await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
+        await execRemove(localTarPath);
+        // Decompress the remote compressed file in and delete it
+        await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
+        await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
+        deferred.resolve();
+
+        return deferred.promise;
+    }
+
     export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> {
         const deferred: Deferred<string> = new Deferred<string>();
-        sshClient.sftp((err : Error, sftp : SFTPWrapper) => {
+        sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
             if (err !== undefined && err !== null) {
                 getLogger().error(`getRemoteFileContent: ${err.message}`);
@@ -133,10 +130,10 @@ export namespace SSHClientUtility {
                 return;
             }
             try {
-                const sftpStream : stream.Readable = sftp.createReadStream(filePath);
+                const sftpStream: stream.Readable = sftp.createReadStream(filePath);
 
                 let dataBuffer: string = '';
-                sftpStream.on('data', (data : Buffer | string) => {
+                sftpStream.on('data', (data: Buffer | string) => {
                     dataBuffer += data;
                 })
                 .on('error', (streamErr: Error) => {
@@ -158,5 +155,4 @@ export namespace SSHClientUtility {
         return deferred.promise;
     }
-    // tslint:enable:no-unsafe-any no-any
 }
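copyDirectoryToRemote keeps its tar-based flow after the move: tar the directory locally, SFTP the archive over, untar on the remote, and remove both temporary archives. A hedged usage sketch, assuming an already-connected ssh2 Client and a Linux remote (paths are placeholders):

import { Client } from 'ssh2';
import { SSHClientUtility } from './sshClientUtility';

// Illustrative only: copy a local code directory to a remote machine and run a script there.
async function deployAndRun(conn: Client): Promise<void> {
    // Tars /tmp/code into a temp .tar.gz, uploads it, and untars into /home/user/code.
    await SSHClientUtility.copyDirectoryToRemote('/tmp/code', '/home/user/code', conn, 'linux');
    // remoteExeCommand resolves with { stdout, stderr, exitCode } once the channel exits.
    const result = await SSHClientUtility.remoteExeCommand('bash /home/user/code/run.sh', conn);
    console.log(`exit code: ${result.exitCode}`);
}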
src/nni_manager/yarn.lock  (view file @ 543239c6)

@@ -703,7 +703,7 @@ buffer-stream-reader@^0.1.1:
   version "0.1.1"
   resolved "https://registry.yarnpkg.com/buffer-stream-reader/-/buffer-stream-reader-0.1.1.tgz#ca8bf93631deedd8b8f8c3bb44991cc30951e259"
 
-builtin-modules@^1.0.0, builtin-modules@^1.1.1:
+builtin-modules@^1.0.0:
   version "1.1.1"
   resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-1.1.1.tgz#270f076c5a72c02f5b65a47df94c5fe3a278892f"
@@ -841,7 +841,7 @@ chalk@^1.0.0:
     strip-ansi "^3.0.0"
     supports-color "^2.0.0"
 
-chalk@^2.0.0, chalk@^2.3.0:
+chalk@^2.0.0:
   version "2.4.1"
   resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.1.tgz#18c49ab16a037b6eb0152cc83e3471338215b66e"
   dependencies:
@@ -971,10 +971,6 @@ commander@2.15.1:
   version "2.15.1"
   resolved "https://registry.yarnpkg.com/commander/-/commander-2.15.1.tgz#df46e867d0fc2aec66a34662b406a9ccafff5b0f"
 
-commander@^2.12.1:
-  version "2.16.0"
-  resolved "https://registry.yarnpkg.com/commander/-/commander-2.16.0.tgz#f16390593996ceb4f3eeb020b31d78528f7f8a50"
-
 commander@~2.17.1:
   version "2.17.1"
   resolved "https://registry.yarnpkg.com/commander/-/commander-2.17.1.tgz#bd77ab7de6de94205ceacc72f1716d29f20a77bf"
@@ -1134,7 +1130,7 @@ debug@^4.0.1, debug@^4.1.0, debug@^4.1.1:
   dependencies:
     ms "^2.1.1"
 
-debuglog@*, debuglog@^1.0.1:
+debuglog@^1.0.1:
   version "1.0.1"
   resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
@@ -1217,7 +1213,7 @@ dezalgo@^1.0.0, dezalgo@~1.0.3:
     asap "^2.0.0"
     wrappy "1"
 
-diff@3.5.0, diff@^3.1.0, diff@^3.2.0:
+diff@3.5.0, diff@^3.1.0:
   version "3.5.0"
   resolved "https://registry.yarnpkg.com/diff/-/diff-3.5.0.tgz#800c0dd1e0a8bfbc95835c202ad220fe317e5a12"
@@ -2080,7 +2076,7 @@ import-lazy@^2.1.0:
   version "2.1.0"
   resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
 
-imurmurhash@*, imurmurhash@^0.1.4:
+imurmurhash@^0.1.4:
   version "0.1.4"
   resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
@@ -2519,10 +2515,6 @@ lockfile@~1.0.3:
   dependencies:
     signal-exit "^3.0.2"
 
-lodash._baseindexof@*:
-  version "3.1.0"
-  resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
-
 lodash._baseuniq@~4.6.0:
   version "4.6.0"
   resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
@@ -2530,28 +2522,10 @@ lodash._baseuniq@~4.6.0:
     lodash._createset "~4.0.0"
     lodash._root "~3.0.0"
 
-lodash._bindcallback@*:
-  version "3.0.1"
-  resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
-
-lodash._cacheindexof@*:
-  version "3.0.2"
-  resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
-
-lodash._createcache@*:
-  version "3.1.2"
-  resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
-  dependencies:
-    lodash._getnative "^3.0.0"
-
 lodash._createset@~4.0.0:
   version "4.0.3"
   resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
 
-lodash._getnative@*, lodash._getnative@^3.0.0:
-  version "3.9.1"
-  resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
-
 lodash._root@~3.0.0:
   version "3.0.1"
   resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
@@ -2600,10 +2574,6 @@ lodash.pick@^4.4.0:
   version "4.4.0"
   resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
 
-lodash.restparam@*:
-  version "3.6.1"
-  resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
-
 lodash.unescape@4.0.1:
   version "4.0.1"
   resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
@@ -3519,10 +3489,6 @@ path-key@^2.0.0, path-key@^2.0.1:
   version "2.0.1"
   resolved "https://registry.yarnpkg.com/path-key/-/path-key-2.0.1.tgz#411cadb574c5a140d3a4b1910d40d80cc9f40b40"
 
-path-parse@^1.0.5:
-  version "1.0.5"
-  resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.5.tgz#3c1adf871ea9cd6c9431b6ea2bd74a0ff055c4c1"
-
 path-parse@^1.0.6:
   version "1.0.6"
   resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.6.tgz#d62dbb5679405d72c4737ec58600e9ddcf06d24c"
@@ -3834,7 +3800,7 @@ readable-stream@~2.0.0:
     string_decoder "~0.10.x"
     util-deprecate "~1.0.1"
 
-readdir-scoped-modules@*, readdir-scoped-modules@^1.0.0:
+readdir-scoped-modules@^1.0.0:
   version "1.1.0"
   resolved "https://registry.yarnpkg.com/readdir-scoped-modules/-/readdir-scoped-modules-1.1.0.tgz#8d45407b4f870a0dcaebc0e28670d18e74514309"
   dependencies:
@@ -3977,12 +3943,6 @@ resolve@^1.10.0:
   dependencies:
     path-parse "^1.0.6"
 
-resolve@^1.3.2:
-  version "1.8.1"
-  resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.8.1.tgz#82f1ec19a423ac1fbd080b0bab06ba36e84a7a26"
-  dependencies:
-    path-parse "^1.0.5"
-
 responselike@1.0.2:
   version "1.0.2"
   resolved "https://registry.yarnpkg.com/responselike/-/responselike-1.0.2.tgz#918720ef3b631c5642be068f15ade5a46f4ba1e7"
@@ -4599,7 +4559,7 @@ ts-node@^7.0.0:
     source-map-support "^0.5.6"
     yn "^2.0.0"
 
-tslib@^1.8.0, tslib@^1.8.1:
+tslib@^1.8.1:
   version "1.9.3"
   resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286"
@@ -4607,42 +4567,6 @@ tslib@^1.9.0:
   version "1.10.0"
   resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.10.0.tgz#c3c19f95973fb0a62973fb09d90d961ee43e5c8a"
 
-tslint-microsoft-contrib@^6.0.0:
-  version "6.2.0"
-  resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-6.2.0.tgz#8aa0f40584d066d05e6a5e7988da5163b85f2ad4"
-  dependencies:
-    tsutils "^2.27.2 <2.29.0"
-
-tslint@^5.12.0:
-  version "5.18.0"
-  resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.18.0.tgz#f61a6ddcf372344ac5e41708095bbf043a147ac6"
-  dependencies:
-    "@babel/code-frame" "^7.0.0"
-    builtin-modules "^1.1.1"
-    chalk "^2.3.0"
-    commander "^2.12.1"
-    diff "^3.2.0"
-    glob "^7.1.1"
-    js-yaml "^3.13.1"
-    minimatch "^3.0.4"
-    mkdirp "^0.5.1"
-    resolve "^1.3.2"
-    semver "^5.3.0"
-    tslib "^1.8.0"
-    tsutils "^2.29.0"
-
-"tsutils@^2.27.2 <2.29.0":
-  version "2.28.0"
-  resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.28.0.tgz#6bd71e160828f9d019b6f4e844742228f85169a1"
-  dependencies:
-    tslib "^1.8.1"
-
-tsutils@^2.29.0:
-  version "2.29.0"
-  resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.29.0.tgz#32b488501467acbedd4b85498673a0812aca0b99"
-  dependencies:
-    tslib "^1.8.1"
-
 tsutils@^3.17.1:
   version "3.17.1"
   resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.17.1.tgz#ed719917f11ca0dee586272b2ac49e015a2dd759"
@@ -4818,7 +4742,7 @@ v8-compile-cache@^2.0.3:
   version "2.1.0"
   resolved "https://registry.yarnpkg.com/v8-compile-cache/-/v8-compile-cache-2.1.0.tgz#e14de37b31a6d194f5690d67efc4e7f6fc6ab30e"
 
-validate-npm-package-license@*, validate-npm-package-license@^3.0.1:
+validate-npm-package-license@^3.0.1:
   version "3.0.4"
   resolved "https://registry.yarnpkg.com/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz#fc91f6b9c7ba15c857f4cb2c5defeec39d4f410a"
   dependencies:
src/sdk/pynni/nni/batch_tuner/batch_tuner.py  (view file @ 543239c6)

@@ -24,13 +24,15 @@ class BatchTuner(Tuner):
     Examples
     --------
     The search space only be accepted like:
-    ```
-    {
-    'combine_params': { '_type': 'choice',
+
+    ::
+
+        {'combine_params':
+            { '_type': 'choice',
               '_value': '[{...}, {...}, {...}]',
             }
         }
-    ```
+
     """
 
     def __init__(self):
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
View file @
543239c6
...
...
@@ -5,7 +5,7 @@ import logging
import
torch
from
.compressor
import
Pruner
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'
FPGM
Pruner'
,
'L1FilterPruner'
,
'
Slim
Pruner'
]
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'
Slim
Pruner'
,
'L1FilterPruner'
,
'
L2FilterPruner'
,
'FPGM
Pruner'
]
logger
=
logging
.
getLogger
(
'torch pruner'
)
...
...
@@ -166,119 +166,132 @@ class AGP_Pruner(Pruner):
self
.
if_init_list
[
k
]
=
True
class
FPGM
Pruner
(
Pruner
):
class
Slim
Pruner
(
Pruner
):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""
def
__init__
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : pytorch model
the model user wants to compress
config_list: list
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super
().
__init__
(
model
,
config_list
)
self
.
mask_dict
=
{}
self
.
epoch_pruned_layers
=
set
()
self
.
mask_calculated_ops
=
set
()
weight_list
=
[]
if
len
(
config_list
)
>
1
:
logger
.
warning
(
'Slim pruner only supports 1 configuration'
)
config
=
config_list
[
0
]
for
(
layer
,
config
)
in
self
.
detect_modules_to_compress
():
assert
layer
.
type
==
'BatchNorm2d'
,
'SlimPruner only supports 2d batch normalization layer pruning'
weight_list
.
append
(
layer
.
module
.
weight
.
data
.
abs
().
clone
())
all_bn_weights
=
torch
.
cat
(
weight_list
)
k
=
int
(
all_bn_weights
.
shape
[
0
]
*
config
[
'sparsity'
])
self
.
global_threshold
=
torch
.
topk
(
all_bn_weights
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
def
calc_mask
(
self
,
layer
,
config
):
"""
Supports Conv1d, Conv2d
filter dimensions for Conv1d:
OUT: number of output channel
IN: number of input channel
LEN: filter length
filter dimensions for Conv2d:
OUT: number of output channel
IN: number of input channel
H: filter height
W: filter width
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
calculate mask for `layer`'s weight
the layer to instrument the compression operation
config : dict
the configuration for generating the mask
layer's pruning config
Returns
-------
torch.Tensor
mask of the layer's weight
"""
weight
=
layer
.
module
.
weight
.
data
assert
0
<=
config
.
get
(
'sparsity'
)
<
1
assert
layer
.
type
in
[
'Conv1d'
,
'Conv2d'
]
assert
layer
.
type
in
config
[
'op_types'
]
op_name
=
layer
.
name
op_type
=
layer
.
type
assert
op_type
==
'BatchNorm2d'
,
'SlimPruner only supports 2d batch normalization layer pruning'
if
op_name
in
self
.
mask_calculated_ops
:
assert
op_name
in
self
.
mask_dict
return
self
.
mask_dict
.
get
(
op_name
)
mask
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
)
try
:
w_abs
=
weight
.
abs
()
mask
=
torch
.
gt
(
w_abs
,
self
.
global_threshold
).
type_as
(
weight
)
finally
:
self
.
mask_dict
.
update
({
layer
.
name
:
mask
})
self
.
mask_calculated_ops
.
add
(
layer
.
name
)
if
layer
.
name
in
self
.
epoch_pruned_layers
:
assert
layer
.
name
in
self
.
mask_dict
return
self
.
mask_dict
.
get
(
layer
.
name
)
return
mask
-        masks = torch.ones(weight.size()).type_as(weight)
-        try:
-            num_filters = weight.size(0)
-            num_prune = int(num_filters * config.get('sparsity'))
-            if num_filters < 2 or num_prune < 1:
-                return masks
-            min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
-            for idx in min_gm_idx:
-                masks[idx] = 0.
-        finally:
-            self.mask_dict.update({layer.name: masks})
-            self.epoch_pruned_layers.add(layer.name)
-        return masks
-
-    def _get_min_gm_kernel_idx(self, weight, n):
-        assert len(weight.size()) in [3, 4]
-        dist_list = []
-        for out_i in range(weight.size(0)):
-            dist_sum = self._get_distance_sum(weight, out_i)
-            dist_list.append((dist_sum, out_i))
-        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
-        return [x[1] for x in min_gm_kernels]
+class RankFilterPruner(Pruner):
+    """
+    A structured pruning base class that prunes the filters with the smallest
+    importance criterion in convolution layers to achieve a preset level of network sparsity.
+    """
+
+    def __init__(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : torch.nn.module
+            Model to be pruned
+        config_list : list
+            support key for each list item:
+                - sparsity: percentage of convolutional filters to be pruned.
+        """
+        super().__init__(model, config_list)
+        self.mask_calculated_ops = set()
+
+    def _get_mask(self, base_mask, weight, num_prune):
+        return torch.ones(weight.size()).type_as(weight)
-    def _get_distance_sum(self, weight, out_idx):
-        """
-        Calculate the total distance between a specified filter (by out_idx) and
-        all other filters.
-        Optimized version of the following naive implementation:
-            def _get_distance_sum(self, weight, in_idx, out_idx):
-                w = weight.view(-1, weight.size(-2), weight.size(-1))
-                dist_sum = 0.
-                for k in w:
-                    dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
-                return dist_sum
-        Parameters
-        ----------
-        weight: Tensor
-            convolutional filter weight
-        out_idx: int
-            output channel index of specified filter, this method calculates the total distance
-            between this specified filter and all other filters.
-        Returns
-        -------
-        float32
-            The total distance
-        """
-        logger.debug('weight size: %s', weight.size())
-        assert len(weight.size()) in [3, 4], 'unsupported weight shape'
-        w = weight.view(weight.size(0), -1)
-        anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
-        x = w - anchor_w
-        x = (x * x).sum(-1)
-        x = torch.sqrt(x)
-        return x.sum()
-
-    def update_epoch(self, epoch):
-        self.epoch_pruned_layers = set()
+
+    def calc_mask(self, layer, config):
+        """
+        Calculate the mask of given layer.
+        Filters with the smallest importance criterion of the kernel weights are masked.
+        Parameters
+        ----------
+        layer : LayerInfo
+            the layer to instrument the compression operation
+        config : dict
+            layer's pruning config
+        Returns
+        -------
+        torch.Tensor
+            mask of the layer's weight
+        """
+        weight = layer.module.weight.data
+        op_name = layer.name
+        op_type = layer.type
+        assert 0 <= config.get('sparsity') < 1
+        assert op_type in ['Conv1d', 'Conv2d']
+        assert op_type in config.get('op_types')
+        if op_name in self.mask_calculated_ops:
+            assert op_name in self.mask_dict
+            return self.mask_dict.get(op_name)
+        mask = torch.ones(weight.size()).type_as(weight)
+        try:
+            filters = weight.size(0)
+            num_prune = int(filters * config.get('sparsity'))
+            if filters < 2 or num_prune < 1:
+                return mask
+            mask = self._get_mask(mask, weight, num_prune)
+        finally:
+            self.mask_dict.update({op_name: mask})
+            self.mask_calculated_ops.add(op_name)
+        return mask.detach()
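The refactor is a classic template method: `calc_mask` owns the bookkeeping (sparsity checks, mask caching in `mask_dict`) while each subclass only supplies a ranking criterion via `_get_mask`. As a sketch of what a new criterion would look like under this base class (the random criterion is purely illustrative and not part of this commit):

    import torch

    class RandomFilterPruner(RankFilterPruner):
        """Illustrative subclass: prunes num_prune filters chosen at random."""
        def _get_mask(self, base_mask, weight, num_prune):
            # Pick num_prune output channels uniformly at random and zero them.
            prune_idx = torch.randperm(weight.size(0))[:num_prune]
            for idx in prune_idx:
                base_mask[idx] = 0.
            return base_mask
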
-class L1FilterPruner(Pruner):
+class L1FilterPruner(RankFilterPruner):
     """
     A structured pruning algorithm that prunes the filters of smallest magnitude
     weights sum in the convolution layers to achieve a preset level of network sparsity.
...
@@ -299,107 +312,162 @@ class L1FilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()

-    def calc_mask(self, layer, config):
-        """
-        Calculate the mask of given layer.
-        Filters with the smallest sum of its absolute kernel weights are masked.
-        Parameters
-        ----------
-        layer : LayerInfo
-            the layer to instrument the compression operation
-        config : dict
-            layer's pruning config
-        Returns
-        -------
-        torch.Tensor
-            mask of the layer's weight
-        """
-        weight = layer.module.weight.data
-        op_name = layer.name
-        op_type = layer.type
-        assert op_type == 'Conv2d', 'L1FilterPruner only supports 2d convolution layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
-        mask = torch.ones(weight.size()).type_as(weight)
-        try:
-            filters = weight.shape[0]
-            w_abs = weight.abs()
-            k = int(filters * config['sparsity'])
-            if k == 0:
-                return torch.ones(weight.shape).type_as(weight)
-            w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
-            threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
-            mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
-        return mask
+    def _get_mask(self, base_mask, weight, num_prune):
+        """
+        Calculate the mask of given layer.
+        Filters with the smallest sum of its absolute kernel weights are masked.
+        Parameters
+        ----------
+        base_mask : torch.Tensor
+            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        weight : torch.Tensor
+            Layer's weight
+        num_prune : int
+            Num of filters to prune
+        Returns
+        -------
+        torch.Tensor
+            Mask of the layer's weight
+        """
+        filters = weight.shape[0]
+        w_abs = weight.abs()
+        w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
+        threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
+        mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
+        return mask
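A quick worked example of the L1 criterion: for a weight of shape (3, 1, 2, 2) whose per-filter absolute sums are 1.0, 4.0, and 0.5, pruning one filter puts the threshold at 0.5, so only the third filter is zeroed (values chosen for illustration):

    import torch

    w = torch.tensor([[[[0.25, 0.25], [0.25, 0.25]]],
                      [[[1.0, 1.0], [1.0, 1.0]]],
                      [[[0.125, 0.125], [0.125, 0.125]]]])
    sums = w.abs().view(3, -1).sum(dim=1)                   # tensor([1.0, 4.0, 0.5])
    threshold = torch.topk(sums, 1, largest=False)[0].max()  # 0.5
    mask = torch.gt(sums, threshold)[:, None, None, None].expand_as(w).float()
    print(mask[:, 0, 0, 0])                                  # tensor([1., 1., 0.])
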
-class SlimPruner(Pruner):
-    """
-    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
-    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
-    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
-    https://arxiv.org/pdf/1708.06519.pdf
-    """
-
-    def __init__(self, model, config_list):
-        """
-        Parameters
-        ----------
-        model : torch.nn.module
-            Model to be pruned
-        config_list : list
-            support key for each list item:
-                - sparsity: percentage of convolutional filters to be pruned.
-        """
-        super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
-        weight_list = []
-        if len(config_list) > 1:
-            logger.warning('Slim pruner only supports 1 configuration')
-        config = config_list[0]
-        for (layer, config) in self.detect_modules_to_compress():
-            assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-            weight_list.append(layer.module.weight.data.abs().clone())
-        all_bn_weights = torch.cat(weight_list)
-        k = int(all_bn_weights.shape[0] * config['sparsity'])
-        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
-
-    def calc_mask(self, layer, config):
-        """
-        Calculate the mask of given layer.
-        Scale factors with the smallest absolute value in the BN layer are masked.
-        Parameters
-        ----------
-        layer : LayerInfo
-            the layer to instrument the compression operation
-        config : dict
-            layer's pruning config
-        Returns
-        -------
-        torch.Tensor
-            mask of the layer's weight
-        """
-        weight = layer.module.weight.data
-        op_name = layer.name
-        op_type = layer.type
-        assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
-        mask = torch.ones(weight.size()).type_as(weight)
-        try:
-            w_abs = weight.abs()
-            mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
-        return mask
+class L2FilterPruner(RankFilterPruner):
+    """
+    A structured pruning algorithm that prunes the filters with the
+    smallest L2 norm of the kernel weights.
+    """
+
+    def __init__(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : torch.nn.module
+            Model to be pruned
+        config_list : list
+            support key for each list item:
+                - sparsity: percentage of convolutional filters to be pruned.
+        """
+        super().__init__(model, config_list)
+
+    def _get_mask(self, base_mask, weight, num_prune):
+        """
+        Calculate the mask of given layer.
+        Filters with the smallest L2 norm of the kernel weights are masked.
+        Parameters
+        ----------
+        base_mask : torch.Tensor
+            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        weight : torch.Tensor
+            Layer's weight
+        num_prune : int
+            Num of filters to prune
+        Returns
+        -------
+        torch.Tensor
+            Mask of the layer's weight
+        """
+        filters = weight.shape[0]
+        w = weight.view(filters, -1)
+        w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
+        threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
+        mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
+        return mask
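The only difference from L1FilterPruner is the per-filter reduction: sum of absolute values versus Euclidean norm. The two criteria can rank filters differently when magnitude is concentrated in a few weights; a toy check with illustrative values:

    import torch

    w = torch.tensor([[2.0, 0.0, 0.0],   # L1 = 2.0, L2 = 2.0
                      [1.0, 1.0, 1.0]])  # L1 = 3.0, L2 ~= 1.73
    l1 = w.abs().sum(dim=1)              # filter 0 looks weaker under L1
    l2 = torch.sqrt((w ** 2).sum(dim=1)) # filter 1 looks weaker under L2
    print(l1.argmin().item(), l2.argmin().item())  # 0 1
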
+class FPGMPruner(RankFilterPruner):
+    """
+    A filter pruner via geometric median.
+    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
+    https://arxiv.org/pdf/1811.00250.pdf
+    """
+
+    def __init__(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : pytorch model
+            the model user wants to compress
+        config_list: list
+            support key for each list item:
+                - sparsity: percentage of convolutional filters to be pruned.
+        """
+        super().__init__(model, config_list)
+
+    def _get_mask(self, base_mask, weight, num_prune):
+        """
+        Calculate the mask of given layer.
+        Filters with the smallest total geometric distance to the other filters are masked.
+        Parameters
+        ----------
+        base_mask : torch.Tensor
+            The basic mask with the same shape of weight, all item in the basic mask is 1.
+        weight : torch.Tensor
+            Layer's weight
+        num_prune : int
+            Num of filters to prune
+        Returns
+        -------
+        torch.Tensor
+            Mask of the layer's weight
+        """
+        min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
+        for idx in min_gm_idx:
+            base_mask[idx] = 0.
+        return base_mask
+
+    def _get_min_gm_kernel_idx(self, weight, n):
+        assert len(weight.size()) in [3, 4]
+        dist_list = []
+        for out_i in range(weight.size(0)):
+            dist_sum = self._get_distance_sum(weight, out_i)
+            dist_list.append((dist_sum, out_i))
+        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
+        return [x[1] for x in min_gm_kernels]
+
+    def _get_distance_sum(self, weight, out_idx):
+        """
+        Calculate the total distance between a specified filter (by out_idx) and
+        all other filters.
+        Optimized version of the following naive implementation:
+            def _get_distance_sum(self, weight, in_idx, out_idx):
+                w = weight.view(-1, weight.size(-2), weight.size(-1))
+                dist_sum = 0.
+                for k in w:
+                    dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
+                return dist_sum
+        Parameters
+        ----------
+        weight: Tensor
+            convolutional filter weight
+        out_idx: int
+            output channel index of specified filter, this method calculates the total distance
+            between this specified filter and all other filters.
+        Returns
+        -------
+        float32
+            The total distance
+        """
+        logger.debug('weight size: %s', weight.size())
+        assert len(weight.size()) in [3, 4], 'unsupported weight shape'
+        w = weight.view(weight.size(0), -1)
+        anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
+        x = w - anchor_w
+        x = (x * x).sum(-1)
+        x = torch.sqrt(x)
+        return x.sum()
+
+    def update_epoch(self, epoch):
+        self.mask_calculated_ops = set()
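Since `_get_distance_sum` replaces a per-kernel loop with broadcasting, the vectorized form is easy to sanity-check against a naive loop. A standalone sketch; `dist_naive` is a hypothetical helper written for this check (it flattens each filter, a simplified variant of the docstring's naive version), not part of the commit:

    import torch

    def dist_naive(weight, out_idx):
        # Loop over filters and accumulate pairwise L2 distances.
        w = weight.view(weight.size(0), -1)
        return sum(torch.dist(k, w[out_idx], p=2) for k in w)

    def dist_vectorized(weight, out_idx):
        # Same computation via broadcasting, mirroring _get_distance_sum.
        w = weight.view(weight.size(0), -1)
        x = w - w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
        return torch.sqrt((x * x).sum(-1)).sum()

    weight = torch.randn(8, 4, 3, 3)
    assert torch.allclose(dist_naive(weight, 2), dist_vectorized(weight, 2), atol=1e-5)
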
src/sdk/pynni/nni/msg_dispatcher_base.py View file @ 543239c6
...
@@ -163,12 +163,15 @@ class MsgDispatcherBase(Recoverable):
         raise NotImplementedError('handle_initialize not implemented')

     def handle_request_trial_jobs(self, data):
-        """The message dispatcher is demanded to generate `data` trial jobs.
-        These trial jobs should be sent via `send(CommandType.NewTrialJob, json_tricks.dumps(parameter))`,
-        where `parameter` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
-        Semantically, message dispatcher should do this `send` exactly `data` times.
+        """The message dispatcher is demanded to generate ``data`` trial jobs.
+        These trial jobs should be sent via ``send(CommandType.NewTrialJob, json_tricks.dumps(parameter))``,
+        where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
+        Semantically, message dispatcher should do this ``send`` exactly ``data`` times.
         The JSON sent by this method should follow the format of
         ::
             {
                 "parameter_id": 42
                 "parameters": {
...
@@ -176,6 +179,7 @@ class MsgDispatcherBase(Recoverable):
                 },
                 "parameter_source": "algorithm" // optional
             }
+
         Parameters
         ----------
         data: int
...
@@ -211,6 +215,7 @@ class MsgDispatcherBase(Recoverable):
     def handle_report_metric_data(self, data):
         """Called when metric data is reported or new parameters are requested (for multiphase).
         When new parameters are requested, this method should send a new parameter.
+
         Parameters
         ----------
         data: dict
...
@@ -219,6 +224,7 @@ class MsgDispatcherBase(Recoverable):
         `REQUEST_PARAMETER` is used to request new parameters for multiphase trial job. In this case,
         the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py`
         as an example.
+
         Raises
         ------
         ValueError
...
@@ -228,6 +234,7 @@ class MsgDispatcherBase(Recoverable):
     def handle_trial_end(self, data):
         """Called when the state of one of the trials is changed
+
         Parameters
         ----------
         data: dict
...
@@ -235,5 +242,6 @@ class MsgDispatcherBase(Recoverable):
             trial_job_id: the id generated by training service.
             event: the job's state.
             hyper_params: the string that is sent by message dispatcher during the creation of trials.
+
         """
         raise NotImplementedError('handle_trial_end not implemented')
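To make the `handle_request_trial_jobs` contract concrete, a minimal custom dispatcher might look like the sketch below. The random search space is a placeholder, and the `nni.protocol` import path reflects this era of the SDK; treat the exact module layout as an assumption:

    import random
    import json_tricks
    from nni.msg_dispatcher_base import MsgDispatcherBase
    from nni.protocol import CommandType, send  # assumed import path

    class RandomDispatcher(MsgDispatcherBase):
        def __init__(self):
            super().__init__()
            self.parameter_id = 0

        def handle_request_trial_jobs(self, data):
            # Called with an int ``data``: emit exactly ``data`` NewTrialJob commands.
            for _ in range(data):
                parameter = {
                    'parameter_id': self.parameter_id,
                    'parameter_source': 'algorithm',
                    'parameters': {'learning_rate': random.uniform(1e-4, 1e-1)}
                }
                send(CommandType.NewTrialJob, json_tricks.dumps(parameter))
                self.parameter_id += 1
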
src/sdk/pynni/tests/test_compressor.py View file @ 543239c6
...
@@ -58,8 +58,9 @@ def tf2(func):
     return test_tf2_func

 # for fpgm filter pruner test
-w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])
+w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])

 class CompressorTestCase(TestCase):
...
@@ -69,19 +70,19 @@ class CompressorTestCase(TestCase):
         config_list = [{
             'quant_types': ['weight'],
             'quant_bits': 8,
-            'op_types':['Conv2d', 'Linear']
+            'op_types': ['Conv2d', 'Linear']
         }, {
             'quant_types': ['output'],
             'quant_bits': 8,
             'quant_start_step': 0,
-            'op_types':['ReLU']
+            'op_types': ['ReLU']
         }]
         model.relu = torch.nn.ReLU()
         quantizer = torch_compressor.QAT_Quantizer(model, config_list)
         quantizer.compress()
         modules_to_compress = quantizer.get_modules_to_compress()
-        modules_to_compress_name = [t[0].name for t in modules_to_compress]
+        modules_to_compress_name = [t[0].name for t in modules_to_compress]
         assert "conv1" in modules_to_compress_name
         assert "conv2" in modules_to_compress_name
         assert "fc1" in modules_to_compress_name
...
@@ -179,7 +180,8 @@ class CompressorTestCase(TestCase):
         w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
                       np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
         model = TorchModel()
-        config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
+        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
+                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
         pruner = torch_compressor.L1FilterPruner(model, config_list)
         model.conv1.weight.data = torch.tensor(w).float()
...
@@ -236,12 +238,12 @@ class CompressorTestCase(TestCase):
         config_list = [{
             'quant_types': ['weight'],
             'quant_bits': 8,
-            'op_types':['Conv2d', 'Linear']
+            'op_types': ['Conv2d', 'Linear']
         }, {
             'quant_types': ['output'],
             'quant_bits': 8,
             'quant_start_step': 0,
-            'op_types':['ReLU']
+            'op_types': ['ReLU']
         }]
         model.relu = torch.nn.ReLU()
         quantizer = torch_compressor.QAT_Quantizer(model, config_list)
...
@@ -271,5 +273,6 @@ class CompressorTestCase(TestCase):
         assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
         assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)

 if __name__ == '__main__':
     main()