Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bf8be1e7
Unverified
Commit
bf8be1e7
authored
Aug 28, 2020
by
Yuge Zhang
Committed by
GitHub
Aug 28, 2020
Browse files
Merge pull request #2837 from microsoft/v1.8
Merge v1.8 back to master
parents
320407b1
e06a9dda
Changes
44
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
305 additions
and
127 deletions
+305
-127
src/nni_manager/training_service/local/localTrainingService.ts
...ni_manager/training_service/local/localTrainingService.ts
+3
-1
src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
...ng_service/remote_machine/remoteMachineTrainingService.ts
+25
-23
src/nni_manager/training_service/reusable/aml/amlClient.ts
src/nni_manager/training_service/reusable/aml/amlClient.ts
+21
-15
src/nni_manager/training_service/reusable/test/amlClient.test.ts
..._manager/training_service/reusable/test/amlClient.test.ts
+29
-0
src/sdk/pynni/nni/compression/tensorflow/compressor.py
src/sdk/pynni/nni/compression/tensorflow/compressor.py
+38
-31
src/sdk/pynni/nni/compression/tensorflow/pruning/one_shot.py
src/sdk/pynni/nni/compression/tensorflow/pruning/one_shot.py
+10
-6
src/sdk/pynni/nni/compression/torch/pruning/sensitivity_pruner.py
...pynni/nni/compression/torch/pruning/sensitivity_pruner.py
+18
-2
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
+5
-1
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
+4
-1
src/sdk/pynni/nni/compression/torch/utils/sensitivity_analysis.py
...pynni/nni/compression/torch/utils/sensitivity_analysis.py
+7
-5
src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
+1
-1
src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
+1
-1
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+2
-2
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
+2
-3
src/sdk/pynni/nni/package_utils.py
src/sdk/pynni/nni/package_utils.py
+14
-26
src/sdk/pynni/tests/test_compressor_tf.py
src/sdk/pynni/tests/test_compressor_tf.py
+97
-0
src/sdk/pynni/tests/test_model_speedup.py
src/sdk/pynni/tests/test_model_speedup.py
+1
-1
src/webui/src/components/Modals/Compare.tsx
src/webui/src/components/Modals/Compare.tsx
+18
-3
src/webui/src/components/public-child/PaiTrialChild.tsx
src/webui/src/components/public-child/PaiTrialChild.tsx
+7
-3
src/webui/src/components/public-child/PaiTrialLog.tsx
src/webui/src/components/public-child/PaiTrialLog.tsx
+2
-2
No files found.
src/nni_manager/training_service/local/localTrainingService.ts
View file @
bf8be1e7
...
@@ -491,7 +491,7 @@ class LocalTrainingService implements TrainingService {
...
@@ -491,7 +491,7 @@ class LocalTrainingService implements TrainingService {
if
(
process
.
platform
===
'
win32
'
)
{
if
(
process
.
platform
===
'
win32
'
)
{
script
.
push
(
`cd $env:NNI_CODE_DIR`
);
script
.
push
(
`cd $env:NNI_CODE_DIR`
);
script
.
push
(
script
.
push
(
`cmd.exe /c
${
localTrialConfig
.
command
}
2>"
${
path
.
join
(
workingDirectory
,
'
stderr
'
)}
"`
,
`cmd.exe /c
${
localTrialConfig
.
command
}
2>
&1 | Out-File
"
${
path
.
join
(
workingDirectory
,
'
stderr
'
)}
"
-encoding utf8
`
,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`
,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`
,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`
,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`
,
`Write $LASTEXITCODE " " $NOW_DATE | Out-File "
${
path
.
join
(
workingDirectory
,
'
.nni
'
,
'
state
'
)}
" -NoNewline -encoding utf8`
);
`Write $LASTEXITCODE " " $NOW_DATE | Out-File "
${
path
.
join
(
workingDirectory
,
'
.nni
'
,
'
state
'
)}
" -NoNewline -encoding utf8`
);
...
@@ -523,6 +523,8 @@ class LocalTrainingService implements TrainingService {
...
@@ -523,6 +523,8 @@ class LocalTrainingService implements TrainingService {
const
runScriptContent
:
string
[]
=
[];
const
runScriptContent
:
string
[]
=
[];
if
(
process
.
platform
!==
'
win32
'
)
{
if
(
process
.
platform
!==
'
win32
'
)
{
runScriptContent
.
push
(
'
#!/bin/bash
'
);
runScriptContent
.
push
(
'
#!/bin/bash
'
);
}
else
{
runScriptContent
.
push
(
`$env:PATH="
${
process
.
env
.
path
}
"`
)
}
}
for
(
const
variable
of
variables
)
{
for
(
const
variable
of
variables
)
{
runScriptContent
.
push
(
setEnvironmentVariable
(
variable
));
runScriptContent
.
push
(
setEnvironmentVariable
(
variable
));
...
...
src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts
View file @
bf8be1e7
...
@@ -87,6 +87,21 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -87,6 +87,21 @@ class RemoteMachineTrainingService implements TrainingService {
this
.
log
.
info
(
'
ssh connection initialized!
'
);
this
.
log
.
info
(
'
ssh connection initialized!
'
);
// set sshConnectionPromises to [] to avoid log information duplicated
// set sshConnectionPromises to [] to avoid log information duplicated
this
.
sshConnectionPromises
=
[];
this
.
sshConnectionPromises
=
[];
// initialize gpuScheduler
this
.
gpuScheduler
=
new
GPUScheduler
(
this
.
machineExecutorManagerMap
);
if
(
this
.
trialConfig
===
undefined
)
{
throw
new
Error
(
"
trial config not initialized!
"
);
}
// Copy codeDir to remote machine
for
(
const
[
rmMeta
,
executorManager
]
of
this
.
machineExecutorManagerMap
.
entries
())
{
const
executor
:
ShellExecutor
=
await
executorManager
.
getExecutor
(
this
.
initExecutorId
);
if
(
executor
!==
undefined
)
{
this
.
machineCopyExpCodeDirPromiseMap
.
set
(
rmMeta
,
executor
.
copyDirectoryToRemote
(
this
.
trialConfig
.
codeDir
,
executor
.
getRemoteCodePath
(
getExperimentId
()))
);
}
}
}
}
while
(
!
this
.
stopping
)
{
while
(
!
this
.
stopping
)
{
while
(
this
.
jobQueue
.
length
>
0
)
{
while
(
this
.
jobQueue
.
length
>
0
)
{
...
@@ -310,7 +325,6 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -310,7 +325,6 @@ class RemoteMachineTrainingService implements TrainingService {
break
;
break
;
case
TrialConfigMetadataKey
.
MACHINE_LIST
:
case
TrialConfigMetadataKey
.
MACHINE_LIST
:
await
this
.
setupConnections
(
value
);
await
this
.
setupConnections
(
value
);
this
.
gpuScheduler
=
new
GPUScheduler
(
this
.
machineExecutorManagerMap
);
break
;
break
;
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
{
case
TrialConfigMetadataKey
.
TRIAL_CONFIG
:
{
const
remoteMachineTrailConfig
:
TrialConfig
=
<
TrialConfig
>
JSON
.
parse
(
value
);
const
remoteMachineTrailConfig
:
TrialConfig
=
<
TrialConfig
>
JSON
.
parse
(
value
);
...
@@ -327,20 +341,8 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -327,20 +341,8 @@ class RemoteMachineTrainingService implements TrainingService {
try
{
try
{
// Validate to make sure codeDir doesn't have too many files
// Validate to make sure codeDir doesn't have too many files
await
validateCodeDir
(
remoteMachineTrailConfig
.
codeDir
);
await
validateCodeDir
(
remoteMachineTrailConfig
.
codeDir
);
// Copy codeDir to remote machine
for
(
const
[
rmMeta
,
executorManager
]
of
this
.
machineExecutorManagerMap
.
entries
())
{
const
executor
:
ShellExecutor
=
await
executorManager
.
getExecutor
(
this
.
initExecutorId
);
if
(
executor
!==
undefined
)
{
this
.
machineCopyExpCodeDirPromiseMap
.
set
(
rmMeta
,
executor
.
copyDirectoryToRemote
(
remoteMachineTrailConfig
.
codeDir
,
executor
.
getRemoteCodePath
(
getExperimentId
()))
);
}
}
}
catch
(
error
)
{
}
catch
(
error
)
{
this
.
log
.
error
(
error
);
this
.
log
.
error
(
error
);
return
Promise
.
reject
(
new
Error
(
error
));
return
Promise
.
reject
(
new
Error
(
error
));
}
}
...
@@ -426,6 +428,11 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -426,6 +428,11 @@ class RemoteMachineTrainingService implements TrainingService {
const
rmMetaList
:
RemoteMachineMeta
[]
=
<
RemoteMachineMeta
[]
>
JSON
.
parse
(
machineList
);
const
rmMetaList
:
RemoteMachineMeta
[]
=
<
RemoteMachineMeta
[]
>
JSON
.
parse
(
machineList
);
for
(
const
rmMeta
of
rmMetaList
)
{
for
(
const
rmMeta
of
rmMetaList
)
{
this
.
sshConnectionPromises
.
push
(
this
.
initRemoteMachineOnConnected
(
rmMeta
));
}
}
private
async
initRemoteMachineOnConnected
(
rmMeta
:
RemoteMachineMeta
):
Promise
<
void
>
{
rmMeta
.
occupiedGpuIndexMap
=
new
Map
<
number
,
number
>
();
rmMeta
.
occupiedGpuIndexMap
=
new
Map
<
number
,
number
>
();
const
executorManager
:
ExecutorManager
=
new
ExecutorManager
(
rmMeta
);
const
executorManager
:
ExecutorManager
=
new
ExecutorManager
(
rmMeta
);
this
.
log
.
info
(
`connecting to
${
rmMeta
.
username
}
@
${
rmMeta
.
ip
}
:
${
rmMeta
.
port
}
`
);
this
.
log
.
info
(
`connecting to
${
rmMeta
.
username
}
@
${
rmMeta
.
ip
}
:
${
rmMeta
.
port
}
`
);
...
@@ -433,12 +440,7 @@ class RemoteMachineTrainingService implements TrainingService {
...
@@ -433,12 +440,7 @@ class RemoteMachineTrainingService implements TrainingService {
this
.
log
.
debug
(
`reached
${
executor
.
name
}
`
);
this
.
log
.
debug
(
`reached
${
executor
.
name
}
`
);
this
.
machineExecutorManagerMap
.
set
(
rmMeta
,
executorManager
);
this
.
machineExecutorManagerMap
.
set
(
rmMeta
,
executorManager
);
this
.
log
.
debug
(
`initializing
${
executor
.
name
}
`
);
this
.
log
.
debug
(
`initializing
${
executor
.
name
}
`
);
this
.
sshConnectionPromises
.
push
(
this
.
initRemoteMachineOnConnected
(
rmMeta
,
executor
));
this
.
log
.
info
(
`connecting to
${
executor
.
name
}
`
);
}
}
private
async
initRemoteMachineOnConnected
(
rmMeta
:
RemoteMachineMeta
,
executor
:
ShellExecutor
):
Promise
<
void
>
{
// Create root working directory after executor is ready
// Create root working directory after executor is ready
const
nniRootDir
:
string
=
executor
.
joinPath
(
executor
.
getTempPath
(),
'
nni
'
);
const
nniRootDir
:
string
=
executor
.
joinPath
(
executor
.
getTempPath
(),
'
nni
'
);
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
getExperimentId
()));
await
executor
.
createFolder
(
executor
.
getRemoteExperimentRootDir
(
getExperimentId
()));
...
...
src/nni_manager/training_service/reusable/aml/amlClient.ts
View file @
bf8be1e7
...
@@ -74,13 +74,11 @@ export class AMLClient {
...
@@ -74,13 +74,11 @@ export class AMLClient {
throw
Error
(
'
python shell client not initialized!
'
);
throw
Error
(
'
python shell client not initialized!
'
);
}
}
this
.
pythonShellClient
.
send
(
'
tracking_url
'
);
this
.
pythonShellClient
.
send
(
'
tracking_url
'
);
let
trackingUrl
=
''
;
this
.
pythonShellClient
.
on
(
'
message
'
,
(
status
:
any
)
=>
{
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
status
:
any
)
{
const
trackingUrl
=
this
.
parseContent
(
'
tracking_url
'
,
status
);
const
items
=
status
.
split
(
'
:
'
);
if
(
trackingUrl
!==
''
)
{
if
(
items
[
0
]
===
'
tracking_url
'
)
{
trackingUrl
=
items
.
splice
(
1
,
items
.
length
).
join
(
''
)
}
deferred
.
resolve
(
trackingUrl
);
deferred
.
resolve
(
trackingUrl
);
}
});
});
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
return
deferred
.
promise
;
return
deferred
.
promise
;
...
@@ -91,12 +89,11 @@ export class AMLClient {
...
@@ -91,12 +89,11 @@ export class AMLClient {
if
(
this
.
pythonShellClient
===
undefined
)
{
if
(
this
.
pythonShellClient
===
undefined
)
{
throw
Error
(
'
python shell client not initialized!
'
);
throw
Error
(
'
python shell client not initialized!
'
);
}
}
let
newStatus
=
oldStatus
;
this
.
pythonShellClient
.
send
(
'
update_status
'
);
this
.
pythonShellClient
.
send
(
'
update_status
'
);
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
status
:
any
)
{
this
.
pythonShellClient
.
on
(
'
message
'
,
(
status
:
any
)
=>
{
const
items
=
status
.
split
(
'
:
'
);
let
newStatus
=
this
.
parseContent
(
'
status
'
,
status
);
if
(
items
[
0
]
===
'
status
'
)
{
if
(
newStatus
===
'
'
)
{
newStatus
=
items
.
splice
(
1
,
items
.
length
).
join
(
''
)
newStatus
=
oldStatus
;
}
}
deferred
.
resolve
(
newStatus
);
deferred
.
resolve
(
newStatus
);
});
});
...
@@ -117,10 +114,10 @@ export class AMLClient {
...
@@ -117,10 +114,10 @@ export class AMLClient {
throw
Error
(
'
python shell client not initialized!
'
);
throw
Error
(
'
python shell client not initialized!
'
);
}
}
this
.
pythonShellClient
.
send
(
'
receive
'
);
this
.
pythonShellClient
.
send
(
'
receive
'
);
this
.
pythonShellClient
.
on
(
'
message
'
,
function
(
command
:
any
)
{
this
.
pythonShellClient
.
on
(
'
message
'
,
(
command
:
any
)
=>
{
const
items
=
command
.
split
(
'
:
'
)
const
message
=
this
.
parseContent
(
'
receive
'
,
command
);
if
(
items
[
0
]
===
'
receive
'
)
{
if
(
message
!==
'
'
)
{
deferred
.
resolve
(
JSON
.
parse
(
command
.
slice
(
8
)
))
deferred
.
resolve
(
JSON
.
parse
(
message
))
}
}
});
});
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
this
.
monitorError
(
this
.
pythonShellClient
,
deferred
);
...
@@ -136,4 +133,13 @@ export class AMLClient {
...
@@ -136,4 +133,13 @@ export class AMLClient {
deferred
.
reject
(
error
);
deferred
.
reject
(
error
);
});
});
}
}
// Parse command content, command format is {head}:{content}
public
parseContent
(
head
:
string
,
command
:
string
):
string
{
const
items
=
command
.
split
(
'
:
'
);
if
(
items
[
0
]
===
head
)
{
return
command
.
slice
(
head
.
length
+
1
);
}
return
''
;
}
}
}
src/nni_manager/training_service/reusable/test/amlClient.test.ts
0 → 100644
View file @
bf8be1e7
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import
*
as
chai
from
'
chai
'
;
import
{
cleanupUnitTest
,
prepareUnitTest
}
from
'
../../../common/utils
'
;
import
chaiAsPromised
=
require
(
"
chai-as-promised
"
);
import
{
AMLClient
}
from
'
../aml/amlClient
'
;
describe
(
'
Unit Test for amlClient
'
,
()
=>
{
before
(()
=>
{
chai
.
should
();
chai
.
use
(
chaiAsPromised
);
prepareUnitTest
();
});
after
(()
=>
{
cleanupUnitTest
();
});
it
(
'
test parseContent
'
,
async
()
=>
{
let
amlClient
:
AMLClient
=
new
AMLClient
(
''
,
''
,
''
,
''
,
''
,
''
,
''
,
''
);
chai
.
assert
.
equal
(
amlClient
.
parseContent
(
'
test
'
,
'
test:1234
'
),
'
1234
'
,
"
The content should be 1234
"
);
chai
.
assert
.
equal
(
amlClient
.
parseContent
(
'
test
'
,
'
abcd:1234
'
),
''
,
"
The content should be null
"
);
});
});
src/sdk/pynni/nni/compression/tensorflow/compressor.py
View file @
bf8be1e7
...
@@ -6,7 +6,10 @@ Abstract base classes for TensorFlow model compression.
...
@@ -6,7 +6,10 @@ Abstract base classes for TensorFlow model compression.
"""
"""
import
logging
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
assert
tf
.
__version__
.
startswith
(
'2'
),
'NNI model compression only supports TensorFlow v2.x'
from
.
import
default_layers
from
.
import
default_layers
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -25,9 +28,9 @@ class LayerInfo:
...
@@ -25,9 +28,9 @@ class LayerInfo:
The layer's name. Note that it's local to sub-model and may differ from its attribute name.
The layer's name. Note that it's local to sub-model and may differ from its attribute name.
type : str
type : str
Name of the layer's class.
Name of the layer's class.
path : list of str
/
int
path : list of str
or tuple of (str,
int
)
The layer object's and its parents' attribute name / list index.
The layer object's and its parents' attribute name / list index.
For example, if the path is `['cells', 2, 'conv']`, then the layer can be accessed as `model.cells[2].conv`.
For example, if the path is `[
(
'cells', 2
)
, 'conv']`, then the layer can be accessed as `model.cells[2].conv`.
config : JSON object
config : JSON object
Selected configuration for this layer. The format is detailed in tutorial.
Selected configuration for this layer. The format is detailed in tutorial.
...
@@ -35,7 +38,7 @@ class LayerInfo:
...
@@ -35,7 +38,7 @@ class LayerInfo:
----------
----------
layer : tf.keras.layers.Layer
layer : tf.keras.layers.Layer
See attributes section.
See attributes section.
path : list of str
/
int
path : list of str
or tuple of (str,
int
)
See attributes section.
See attributes section.
"""
"""
...
@@ -75,6 +78,8 @@ class Compressor:
...
@@ -75,6 +78,8 @@ class Compressor:
def
__init__
(
self
,
LayerWrapperClass
,
model
,
config_list
):
def
__init__
(
self
,
LayerWrapperClass
,
model
,
config_list
):
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
if
isinstance
(
model
,
tf
.
keras
.
Sequential
):
raise
ValueError
(
'NNI model compression does not support `Sequential` model for now'
)
self
.
validate_config
(
model
,
config_list
)
self
.
validate_config
(
model
,
config_list
)
self
.
bound_model
=
model
self
.
bound_model
=
model
...
@@ -204,10 +209,12 @@ class PrunerLayerWrapper(tf.keras.Model):
...
@@ -204,10 +209,12 @@ class PrunerLayerWrapper(tf.keras.Model):
for
weight
in
self
.
layer
.
weights
:
for
weight
in
self
.
layer
.
weights
:
mask
=
self
.
masks
.
get
(
weight
.
name
)
mask
=
self
.
masks
.
get
(
weight
.
name
)
if
mask
is
not
None
:
if
mask
is
not
None
:
new_weights
.
append
(
tf
.
math
.
multiply
(
weight
,
mask
)
.
numpy
()
)
new_weights
.
append
(
tf
.
math
.
multiply
(
weight
,
mask
))
else
:
else
:
new_weights
.
append
(
weight
.
numpy
())
new_weights
.
append
(
weight
)
self
.
layer
.
set_weights
(
new_weights
)
if
new_weights
and
not
hasattr
(
new_weights
[
0
],
'numpy'
):
raise
RuntimeError
(
'NNI: Compressed model can only run in eager mode'
)
self
.
layer
.
set_weights
([
weight
.
numpy
()
for
weight
in
new_weights
])
return
self
.
layer
(
*
inputs
)
return
self
.
layer
(
*
inputs
)
...
@@ -244,26 +251,21 @@ def _locate_layers(model, cur_path=[]):
...
@@ -244,26 +251,21 @@ def _locate_layers(model, cur_path=[]):
# and to my knowledge `Layer.name` is only useful for read-only access.
# and to my knowledge `Layer.name` is only useful for read-only access.
# `cur_path`s format is documented in `LayerInfo.path`.
# `cur_path`s format is documented in `LayerInfo.path`.
# TODO: it can only find layers in `Model` and `list` for now.
# TODO: it can only find layers in `Model` and `list` for now.
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
if
isinstance
(
model
,
tf
.
keras
.
Sequential
):
_logger
.
warning
(
'`Sequential` model is not supported yet, ignored.'
)
ret
=
{}
ret
=
{}
if
isinstance
(
model
,
tf
.
keras
.
Model
):
for
key
,
value
in
model
.
__dict__
.
items
():
for
key
,
value
in
model
.
__dict__
.
items
():
if
isinstance
(
value
,
tf
.
keras
.
Model
):
if
isinstance
(
value
,
tf
.
keras
.
Model
):
ret
.
update
(
_locate_layers
(
value
,
cur_path
+
[
key
]))
ret
.
update
(
_locate_layers
(
value
,
cur_path
+
[
key
]))
elif
isinstance
(
value
,
list
):
ret
.
update
(
_locate_layers
(
value
,
cur_path
+
[
key
]))
elif
isinstance
(
value
,
tf
.
keras
.
layers
.
Layer
):
elif
isinstance
(
value
,
tf
.
keras
.
layers
.
Layer
):
ret
[
id
(
value
)]
=
LayerInfo
(
value
,
cur_path
+
[
key
])
ret
[
id
(
value
)]
=
LayerInfo
(
value
,
cur_path
+
[
key
])
elif
isinstance
(
value
,
list
):
elif
isinstance
(
model
,
list
):
for
i
,
item
in
enumerate
(
value
):
for
i
,
item
in
enumerate
(
model
):
if
isinstance
(
item
,
tf
.
keras
.
Model
):
if
isinstance
(
item
,
tf
.
keras
.
Model
):
ret
.
update
(
_locate_layers
(
item
,
cur_path
+
[
i
]))
ret
.
update
(
_locate_layers
(
item
,
cur_path
+
[
(
key
,
i
)
]))
elif
isinstance
(
item
,
tf
.
keras
.
layers
.
Layer
):
elif
isinstance
(
item
,
tf
.
keras
.
layers
.
Layer
):
ret
[
id
(
item
)]
=
LayerInfo
(
item
,
cur_path
+
[
i
])
ret
[
id
(
item
)]
=
LayerInfo
(
item
,
cur_path
+
[(
key
,
i
)])
else
:
raise
ValueError
(
'Unexpected model type: {}'
.
format
(
type
(
model
)))
return
ret
return
ret
def
_select_config
(
layer_info
,
config_list
):
def
_select_config
(
layer_info
,
config_list
):
...
@@ -289,12 +291,17 @@ def _instrument_model(model, wrappers):
...
@@ -289,12 +291,17 @@ def _instrument_model(model, wrappers):
for
wrapper
in
reversed
(
wrappers
):
for
wrapper
in
reversed
(
wrappers
):
cur
=
model
cur
=
model
for
key
in
wrapper
.
layer_info
.
path
[:
-
1
]:
for
key
in
wrapper
.
layer_info
.
path
[:
-
1
]:
if
isinstance
(
key
,
int
):
if
isinstance
(
key
,
str
):
cur
=
cur
[
key
]
else
:
cur
=
getattr
(
cur
,
key
)
cur
=
getattr
(
cur
,
key
)
key
=
wrapper
.
layer_info
.
path
[
-
1
]
if
isinstance
(
key
,
int
):
cur
[
key
]
=
wrapper
else
:
else
:
name
,
index
=
key
cur
=
getattr
(
cur
,
name
)[
index
]
key
=
wrapper
.
layer_info
.
path
[
-
1
]
if
isinstance
(
key
,
str
):
setattr
(
cur
,
key
,
wrapper
)
setattr
(
cur
,
key
,
wrapper
)
else
:
name
,
index
=
key
getattr
(
cur
,
name
)[
index
]
=
wrapper
#if isinstance(cur, tf.keras.Sequential):
# cur._graph_initialized = False
# cur._layer_call_argspecs[wrapper] = cur._layer_call_argspecs[wrapper.layer]
src/sdk/pynni/nni/compression/tensorflow/pruning/one_shot.py
View file @
bf8be1e7
...
@@ -44,20 +44,24 @@ class LevelPrunerMasker(WeightMasker):
...
@@ -44,20 +44,24 @@ class LevelPrunerMasker(WeightMasker):
def
calc_masks
(
self
,
sparsity
,
wrapper
,
wrapper_idx
=
None
):
def
calc_masks
(
self
,
sparsity
,
wrapper
,
wrapper_idx
=
None
):
masks
=
{}
masks
=
{}
for
weight_variable
in
wrapper
.
layer
.
weights
:
for
weight_variable
in
wrapper
.
layer
.
weights
:
if
weight_variable
.
name
==
'bias'
:
if
'bias'
in
weight_variable
.
name
:
continue
continue
k
=
int
(
tf
.
size
(
weight_variable
).
numpy
()
*
sparsity
)
num_prune
=
int
(
tf
.
size
(
weight_variable
).
numpy
()
*
sparsity
)
if
k
==
0
:
if
num_prune
==
0
:
continue
continue
weight
=
weight_variable
.
read_value
()
weight
=
weight_variable
.
read_value
()
if
wrapper
.
masks
.
get
(
weight_variable
.
name
)
is
not
None
:
if
wrapper
.
masks
.
get
(
weight_variable
.
name
)
is
not
None
:
weight
=
tf
.
math
.
multiply
(
weight
,
wrapper
.
masks
[
weight_variable
.
name
])
weight
=
tf
.
math
.
multiply
(
weight
,
wrapper
.
masks
[
weight_variable
.
name
])
w_abs
=
tf
.
math
.
abs
(
tf
.
reshape
(
weight
,
[
-
1
]))
w_abs
=
tf
.
math
.
abs
(
weight
)
threshold
=
tf
.
math
.
top_k
(
w_abs
,
k
)[
0
][
0
]
k
=
tf
.
size
(
weight
)
-
num_prune
mask
=
tf
.
math
.
greater
(
w_abs
,
threshold
)
topk
=
tf
.
math
.
top_k
(
tf
.
reshape
(
w_abs
,
[
-
1
]),
k
)[
0
]
if
tf
.
size
(
topk
)
==
0
:
mask
=
tf
.
zeros_like
(
weight
)
else
:
mask
=
tf
.
math
.
greater_equal
(
w_abs
,
topk
[
-
1
])
masks
[
weight_variable
.
name
]
=
tf
.
cast
(
mask
,
weight
.
dtype
)
masks
[
weight_variable
.
name
]
=
tf
.
cast
(
mask
,
weight
.
dtype
)
return
masks
return
masks
...
...
src/sdk/pynni/nni/compression/torch/pruning/sensitivity_pruner.py
View file @
bf8be1e7
...
@@ -17,7 +17,7 @@ from ..utils.sensitivity_analysis import SensitivityAnalysis
...
@@ -17,7 +17,7 @@ from ..utils.sensitivity_analysis import SensitivityAnalysis
MAX_PRUNE_RATIO_PER_ITER
=
0.95
MAX_PRUNE_RATIO_PER_ITER
=
0.95
_logger
=
logging
.
getLogger
(
'Sensitivity_Pruner'
)
_logger
=
logging
.
getLogger
(
'Sensitivity_Pruner'
)
_logger
.
setLevel
(
logging
.
INFO
)
class
SensitivityPruner
(
Pruner
):
class
SensitivityPruner
(
Pruner
):
"""
"""
...
@@ -202,10 +202,10 @@ class SensitivityPruner(Pruner):
...
@@ -202,10 +202,10 @@ class SensitivityPruner(Pruner):
prune_ratios
=
sorted
(
sensitivities
[
layer
].
keys
())
prune_ratios
=
sorted
(
sensitivities
[
layer
].
keys
())
last_ratio
=
0
last_ratio
=
0
for
ratio
in
prune_ratios
:
for
ratio
in
prune_ratios
:
last_ratio
=
ratio
cur_acc
=
sensitivities
[
layer
][
ratio
]
cur_acc
=
sensitivities
[
layer
][
ratio
]
if
cur_acc
+
threshold
<
ori_acc
:
if
cur_acc
+
threshold
<
ori_acc
:
break
break
last_ratio
=
ratio
max_ratio
[
layer
]
=
last_ratio
max_ratio
[
layer
]
=
last_ratio
return
max_ratio
return
max_ratio
...
@@ -244,6 +244,7 @@ class SensitivityPruner(Pruner):
...
@@ -244,6 +244,7 @@ class SensitivityPruner(Pruner):
# MAX_PRUNE_RATIO_PER_ITER we rescal all prune
# MAX_PRUNE_RATIO_PER_ITER we rescal all prune
# ratios under this threshold
# ratios under this threshold
if
_Max
>
MAX_PRUNE_RATIO_PER_ITER
:
if
_Max
>
MAX_PRUNE_RATIO_PER_ITER
:
for
layername
in
ratios
:
for
layername
in
ratios
:
ratios
[
layername
]
=
ratios
[
layername
]
*
\
ratios
[
layername
]
=
ratios
[
layername
]
*
\
MAX_PRUNE_RATIO_PER_ITER
/
_Max
MAX_PRUNE_RATIO_PER_ITER
/
_Max
...
@@ -317,6 +318,7 @@ class SensitivityPruner(Pruner):
...
@@ -317,6 +318,7 @@ class SensitivityPruner(Pruner):
finetune_kwargs
=
{}
finetune_kwargs
=
{}
if
self
.
ori_acc
is
None
:
if
self
.
ori_acc
is
None
:
self
.
ori_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
self
.
ori_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
assert
isinstance
(
self
.
ori_acc
,
float
)
or
isinstance
(
self
.
ori_acc
,
int
)
if
not
resume_sensitivity
:
if
not
resume_sensitivity
:
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
...
@@ -330,6 +332,7 @@ class SensitivityPruner(Pruner):
...
@@ -330,6 +332,7 @@ class SensitivityPruner(Pruner):
iteration_count
=
0
iteration_count
=
0
if
self
.
checkpoint_dir
is
not
None
:
if
self
.
checkpoint_dir
is
not
None
:
os
.
makedirs
(
self
.
checkpoint_dir
,
exist_ok
=
True
)
os
.
makedirs
(
self
.
checkpoint_dir
,
exist_ok
=
True
)
modules_wrapper_final
=
None
while
cur_ratio
>
target_ratio
:
while
cur_ratio
>
target_ratio
:
iteration_count
+=
1
iteration_count
+=
1
# Each round have three steps:
# Each round have three steps:
...
@@ -343,9 +346,16 @@ class SensitivityPruner(Pruner):
...
@@ -343,9 +346,16 @@ class SensitivityPruner(Pruner):
# layers according to the sensitivity result
# layers according to the sensitivity result
proportion
=
self
.
sparsity_proportion_calc
(
proportion
=
self
.
sparsity_proportion_calc
(
ori_acc
,
self
.
acc_drop_threshold
,
self
.
sensitivities
)
ori_acc
,
self
.
acc_drop_threshold
,
self
.
sensitivities
)
new_pruneratio
=
self
.
normalize
(
proportion
,
self
.
sparsity_per_iter
)
new_pruneratio
=
self
.
normalize
(
proportion
,
self
.
sparsity_per_iter
)
cfg_list
=
self
.
create_cfg
(
new_pruneratio
)
cfg_list
=
self
.
create_cfg
(
new_pruneratio
)
if
not
cfg_list
:
_logger
.
error
(
'The threshold is too small, please set a larger threshold'
)
return
self
.
model
_logger
.
debug
(
'Pruner Config: %s'
,
str
(
cfg_list
))
_logger
.
debug
(
'Pruner Config: %s'
,
str
(
cfg_list
))
cfg_str
=
[
'%s:%.3f'
%
(
cfg
[
'op_names'
][
0
],
cfg
[
'sparsity'
])
for
cfg
in
cfg_list
]
_logger
.
info
(
'Current Sparsities: %s'
,
','
.
join
(
cfg_str
))
pruner
=
self
.
Pruner
(
self
.
model
,
cfg_list
)
pruner
=
self
.
Pruner
(
self
.
model
,
cfg_list
)
pruner
.
compress
()
pruner
.
compress
()
pruned_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
pruned_acc
=
self
.
evaluator
(
*
eval_args
,
**
eval_kwargs
)
...
@@ -367,6 +377,7 @@ class SensitivityPruner(Pruner):
...
@@ -367,6 +377,7 @@ class SensitivityPruner(Pruner):
self
.
analyzer
.
already_pruned
[
name
]
=
sparsity
self
.
analyzer
.
already_pruned
[
name
]
=
sparsity
# update the cur_ratio
# update the cur_ratio
cur_ratio
=
1
-
self
.
current_sparsity
()
cur_ratio
=
1
-
self
.
current_sparsity
()
modules_wrapper_final
=
pruner
.
get_modules_wrapper
()
del
pruner
del
pruner
_logger
.
info
(
'Currently remained weights: %f'
,
cur_ratio
)
_logger
.
info
(
'Currently remained weights: %f'
,
cur_ratio
)
...
@@ -383,14 +394,19 @@ class SensitivityPruner(Pruner):
...
@@ -383,14 +394,19 @@ class SensitivityPruner(Pruner):
with
open
(
cfg_path
,
'w'
)
as
jf
:
with
open
(
cfg_path
,
'w'
)
as
jf
:
json
.
dump
(
cfg_list
,
jf
)
json
.
dump
(
cfg_list
,
jf
)
self
.
analyzer
.
export
(
sensitivity_path
)
self
.
analyzer
.
export
(
sensitivity_path
)
if
cur_ratio
>
target_ratio
:
if
cur_ratio
>
target_ratio
:
# If this is the last prune iteration, skip the time-consuming
# If this is the last prune iteration, skip the time-consuming
# sensitivity analysis
# sensitivity analysis
self
.
analyzer
.
load_state_dict
(
self
.
model
.
state_dict
())
self
.
analyzer
.
load_state_dict
(
self
.
model
.
state_dict
())
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
self
.
sensitivities
=
self
.
analyzer
.
analysis
(
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
val_args
=
eval_args
,
val_kwargs
=
eval_kwargs
)
_logger
.
info
(
'After Pruning: %.2f weights remains'
,
cur_ratio
)
_logger
.
info
(
'After Pruning: %.2f weights remains'
,
cur_ratio
)
self
.
modules_wrapper
=
modules_wrapper_final
self
.
_wrap_model
()
return
self
.
model
return
self
.
model
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
...
...
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
View file @
bf8be1e7
...
@@ -222,6 +222,10 @@ infer_from_inshape = {
...
@@ -222,6 +222,10 @@ infer_from_inshape = {
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'ReLU6'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::tanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::hardtanh_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'aten::relu_'
:
lambda
module_masks
,
mask
:
relu_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'Conv2d'
:
lambda
module_masks
,
mask
:
conv2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'MaxPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
...
@@ -282,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
...
@@ -282,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
Parameters
Parameters
----------
----------
module_masks : ModuleMasks
module_masks : ModuleMasks
The ModuleMasks instance of the
batchnorm
2d
The ModuleMasks instance of the
Conv
2d
mask : CoarseMask
mask : CoarseMask
The mask of its input tensor
The mask of its input tensor
cat_info: dict
cat_info: dict
...
...
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
View file @
bf8be1e7
...
@@ -118,11 +118,14 @@ class CatMaskPadding(MaskFix):
...
@@ -118,11 +118,14 @@ class CatMaskPadding(MaskFix):
continue
continue
# pad the mask for the non-pruned layers
# pad the mask for the non-pruned layers
for
layer
in
layers
:
for
layer
in
layers
:
if
layer
in
self
.
masks
:
continue
module
=
name_to_module
[
layer
]
module
=
name_to_module
[
layer
]
w_shape
=
module
.
weight
.
data
.
size
()
w_shape
=
module
.
weight
.
data
.
size
()
w_mask
=
torch
.
ones
(
w_shape
).
to
(
device
)
w_mask
=
torch
.
ones
(
w_shape
).
to
(
device
)
b_mask
=
None
b_mask
=
None
if
hasattr
(
module
,
'bias'
):
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
# module.bias may be None
b_shape
=
module
.
bias
.
data
.
size
()
b_shape
=
module
.
bias
.
data
.
size
()
b_mask
=
torch
.
ones
(
b_shape
).
to
(
device
)
b_mask
=
torch
.
ones
(
b_shape
).
to
(
device
)
self
.
masks
[
layer
]
=
{
'weight'
:
w_mask
,
'bias'
:
b_mask
}
self
.
masks
[
layer
]
=
{
'weight'
:
w_mask
,
'bias'
:
b_mask
}
...
...
src/sdk/pynni/nni/compression/torch/utils/sensitivity_analysis.py
View file @
bf8be1e7
...
@@ -163,7 +163,7 @@ class SensitivityAnalysis:
...
@@ -163,7 +163,7 @@ class SensitivityAnalysis:
if
val_kwargs
is
None
:
if
val_kwargs
is
None
:
val_kwargs
=
{}
val_kwargs
=
{}
# Get the original validation metric(accuracy/loss) before pruning
# Get the original validation metric(accuracy/loss) before pruning
if
self
.
ori_metric
is
None
:
# Get the accuracy baseline before starting the analysis.
self
.
ori_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
self
.
ori_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
namelist
=
list
(
self
.
target_layer
.
keys
())
namelist
=
list
(
self
.
target_layer
.
keys
())
if
specified_layers
is
not
None
:
if
specified_layers
is
not
None
:
...
@@ -172,19 +172,21 @@ class SensitivityAnalysis:
...
@@ -172,19 +172,21 @@ class SensitivityAnalysis:
for
name
in
namelist
:
for
name
in
namelist
:
self
.
sensitivities
[
name
]
=
{}
self
.
sensitivities
[
name
]
=
{}
for
sparsity
in
self
.
sparsities
:
for
sparsity
in
self
.
sparsities
:
# here the sparsity is the relative sparsity of the
# the remained weights
# Calculate the actual prune ratio based on the already pruned ratio
# Calculate the actual prune ratio based on the already pruned ratio
sparsity
=
(
real_
sparsity
=
(
1.0
-
self
.
already_pruned
[
name
])
*
sparsity
+
self
.
already_pruned
[
name
]
1.0
-
self
.
already_pruned
[
name
])
*
sparsity
+
self
.
already_pruned
[
name
]
# TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
# TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary
# I think the L1/L2 Pruner should specify the op_types automaticlly
# I think the L1/L2 Pruner should specify the op_types automaticlly
# according to the op_names
# according to the op_names
cfg
=
[{
'sparsity'
:
sparsity
,
'op_names'
:
[
cfg
=
[{
'sparsity'
:
real_
sparsity
,
'op_names'
:
[
name
],
'op_types'
:
[
'Conv2d'
]}]
name
],
'op_types'
:
[
'Conv2d'
]}]
pruner
=
self
.
Pruner
(
self
.
model
,
cfg
)
pruner
=
self
.
Pruner
(
self
.
model
,
cfg
)
pruner
.
compress
()
pruner
.
compress
()
val_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
val_metric
=
self
.
val_func
(
*
val_args
,
**
val_kwargs
)
logger
.
info
(
'Layer: %s Sparsity: %.2f Validation Metric: %.4f'
,
logger
.
info
(
'Layer: %s Sparsity: %.2f Validation Metric: %.4f'
,
name
,
sparsity
,
val_metric
)
name
,
real_
sparsity
,
val_metric
)
self
.
sensitivities
[
name
][
sparsity
]
=
val_metric
self
.
sensitivities
[
name
][
sparsity
]
=
val_metric
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
...
...
src/sdk/pynni/nni/nas/benchmarks/nasbench101/query.py
View file @
bf8be1e7
...
@@ -15,7 +15,7 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
...
@@ -15,7 +15,7 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
arch : dict or None
arch : dict or None
If a dict, it is in the format that is described in
If a dict, it is in the format that is described in
:class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats
:class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats
matched will be returned. If none, architecture
will be a wildcar
d.
matched will be returned. If none,
all
architecture
s in the database will be matche
d.
num_epochs : int or None
num_epochs : int or None
If int, matching results will be returned. Otherwise a wildcard.
If int, matching results will be returned. Otherwise a wildcard.
isomorphism : boolean
isomorphism : boolean
...
...
src/sdk/pynni/nni/nas/benchmarks/nasbench201/query.py
View file @
bf8be1e7
...
@@ -14,7 +14,7 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
...
@@ -14,7 +14,7 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
arch : dict or None
arch : dict or None
If a dict, it is in the format that is described in
If a dict, it is in the format that is described in
:class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. Only trial stats
:class:`nni.nas.benchmark.nasbench201.Nb201TrialConfig`. Only trial stats
matched will be returned. If none, architecture
will be a wildcar
d.
matched will be returned. If none,
all
architecture
s in the database will be matche
d.
num_epochs : int or None
num_epochs : int or None
If int, matching results will be returned. Otherwise a wildcard.
If int, matching results will be returned. Otherwise a wildcard.
dataset : str or None
dataset : str or None
...
...
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
bf8be1e7
...
@@ -162,7 +162,7 @@ class Mutator(BaseMutator):
...
@@ -162,7 +162,7 @@ class Mutator(BaseMutator):
if
self
.
_connect_all
:
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
[
op
(
*
args
,
**
kwargs
)
for
op
in
mutable
]),
\
[
op
(
*
args
,
**
kwargs
)
for
op
in
mutable
]),
\
torch
.
ones
(
len
(
mutable
))
torch
.
ones
(
len
(
mutable
))
.
bool
()
def
_map_fn
(
op
,
args
,
kwargs
):
def
_map_fn
(
op
,
args
,
kwargs
):
return
op
(
*
args
,
**
kwargs
)
return
op
(
*
args
,
**
kwargs
)
...
@@ -192,7 +192,7 @@ class Mutator(BaseMutator):
...
@@ -192,7 +192,7 @@ class Mutator(BaseMutator):
"""
"""
if
self
.
_connect_all
:
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
tensor_list
),
\
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
tensor_list
),
\
torch
.
ones
(
mutable
.
n_candidates
)
torch
.
ones
(
mutable
.
n_candidates
)
.
bool
()
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
...
...
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
View file @
bf8be1e7
...
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
...
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
"""
"""
def
__init__
(
self
,
num_nodes
,
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
):
def
__init__
(
self
,
num_nodes
,
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
):
super
().
__init__
()
super
().
__init__
()
print
(
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
)
self
.
reduction
=
reduction
self
.
reduction
=
reduction
if
self
.
reduction
:
if
self
.
reduction
:
self
.
reduce0
=
FactorizedReduce
(
in_channels_pp
,
out_channels
,
affine
=
False
)
self
.
reduce0
=
FactorizedReduce
(
in_channels_pp
,
out_channels
,
affine
=
False
)
...
@@ -110,7 +109,7 @@ class ENASMicroLayer(nn.Module):
...
@@ -110,7 +109,7 @@ class ENASMicroLayer(nn.Module):
pprev: torch.Tensor
pprev: torch.Tensor
the output of the previous previous layer
the output of the previous previous layer
prev: torch.Tensor
prev: torch.Tensor
the output of the previous
previous
layer
the output of the previous layer
"""
"""
if
self
.
reduction
:
if
self
.
reduction
:
pprev
,
prev
=
self
.
reduce0
(
pprev
),
self
.
reduce1
(
prev
)
pprev
,
prev
=
self
.
reduce0
(
pprev
),
self
.
reduce1
(
prev
)
...
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
...
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
PoolBranch
(
'avg'
,
in_filters
,
out_filters
,
3
,
1
,
1
),
PoolBranch
(
'avg'
,
in_filters
,
out_filters
,
3
,
1
,
1
),
PoolBranch
(
'max'
,
in_filters
,
out_filters
,
3
,
1
,
1
)
PoolBranch
(
'max'
,
in_filters
,
out_filters
,
3
,
1
,
1
)
])
])
if
prev_labels
>
0
:
if
prev_labels
:
self
.
skipconnect
=
mutables
.
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
None
)
self
.
skipconnect
=
mutables
.
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
None
)
else
:
else
:
self
.
skipconnect
=
None
self
.
skipconnect
=
None
...
...
src/sdk/pynni/nni/package_utils.py
View file @
bf8be1e7
...
@@ -286,32 +286,9 @@ def create_customized_class_instance(class_params):
...
@@ -286,32 +286,9 @@ def create_customized_class_instance(class_params):
return
instance
return
instance
def
get_python_dir
(
sitepackages_path
):
if
sys
.
platform
==
"win32"
:
return
str
(
Path
(
sitepackages_path
))
else
:
return
str
(
Path
(
sitepackages_path
).
parents
[
2
])
def
get_nni_installation_parent_dir
():
def
get_nni_installation_parent_dir
():
''' Find nni installation parent directory
''' Find nni installation parent directory
'''
'''
def
try_installation_path_sequentially
(
*
sitepackages
):
'''Try different installation path sequentially util nni is found.
Return None if nothing is found
'''
def
_generate_installation_path
(
sitepackages_path
):
python_dir
=
get_python_dir
(
sitepackages_path
)
entry_file
=
os
.
path
.
join
(
python_dir
,
'nni'
,
'main.js'
)
if
os
.
path
.
isfile
(
entry_file
):
return
python_dir
return
None
for
sitepackage
in
sitepackages
:
python_dir
=
_generate_installation_path
(
sitepackage
)
if
python_dir
:
return
python_dir
return
None
if
os
.
getenv
(
'VIRTUAL_ENV'
):
if
os
.
getenv
(
'VIRTUAL_ENV'
):
# if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
# if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
# Note that conda venv will not have VIRTUAL_ENV
# Note that conda venv will not have VIRTUAL_ENV
...
@@ -321,12 +298,23 @@ def get_nni_installation_parent_dir():
...
@@ -321,12 +298,23 @@ def get_nni_installation_parent_dir():
# If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
# If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
# that nni exists there
# that nni exists there
if
python_sitepackage
.
startswith
(
'/usr'
)
or
python_sitepackage
.
startswith
(
'/Library'
):
if
python_sitepackage
.
startswith
(
'/usr'
)
or
python_sitepackage
.
startswith
(
'/Library'
):
python_dir
=
try_installation_path_sequentially
(
site
.
getusersitepackages
(),
site
.
getsitepackages
()
[
0
]
)
python_dir
=
_
try_installation_path_sequentially
(
site
.
getusersitepackages
(),
*
site
.
getsitepackages
())
else
:
else
:
python_dir
=
try_installation_path_sequentially
(
site
.
getsitepackages
()[
0
],
site
.
getusersitepackages
())
python_dir
=
_try_installation_path_sequentially
(
*
site
.
getsitepackages
(),
site
.
getusersitepackages
())
return
python_dir
return
python_dir
def
_try_installation_path_sequentially
(
*
sitepackages
):
'''Try different installation path sequentially util nni is found.
Return None if nothing is found
'''
for
sitepackage
in
sitepackages
:
path
=
Path
(
sitepackage
)
if
len
(
path
.
parents
)
>
2
and
(
path
.
parents
[
2
]
/
'nni'
/
'main.js'
).
is_file
():
return
str
(
path
.
parents
[
2
])
if
(
path
/
'nni'
/
'main.js'
).
is_file
():
return
str
(
path
)
return
None
def
get_nni_installation_path
():
def
get_nni_installation_path
():
''' Find nni installation directory
''' Find nni installation directory
'''
'''
...
...
src/sdk/pynni/tests/test_compressor_tf.py
0 → 100644
View file @
bf8be1e7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
unittest
import
numpy
as
np
import
tensorflow
as
tf
####
#
# This file tests pruners on 2 models: a classic CNN model, and a naive model with one linear layer
#
# The CNN model is used to test layer detecting and instrumenting.
#
# The naive model is used to test mask calculation.
# It has a single 10x10 linear layer without bias, and `reduce_sum` its result.
# To help predicting pruning result, the linear layer has fixed initial weights:
# [ [ 0.0, 1.0, 2.0, ..., 9.0 ], [0.1, 1.1, 2.1, ..., 9.1 ], ... , [0.9, 1.0, 2.9, ..., 9.9 ] ]
#
####
# This tensor is used as input of 10x10 linear layer, the first dimension is batch size
tensor1x10
=
tf
.
constant
([[
1.0
]
*
10
])
@
unittest
.
skipIf
(
tf
.
__version__
[
0
]
!=
'2'
,
'Skip TF 1.x setup'
)
class
TfCompressorTestCase
(
unittest
.
TestCase
):
def
test_layer_detection
(
self
):
# Conv and dense layers should be compressed, pool and flatten should not.
# This also tests instrumenting functionality.
self
.
_test_layer_detection_on_model
(
CnnModel
())
def
_test_layer_detection_on_model
(
self
,
model
):
pruner
=
pruners
[
'level'
](
model
)
pruner
.
compress
()
layer_types
=
sorted
(
wrapper
.
layer_info
.
type
for
wrapper
in
pruner
.
wrappers
)
assert
layer_types
==
[
'Conv2D'
,
'Dense'
,
'Dense'
],
layer_types
def
test_level_pruner
(
self
):
# prune 90% : 9.0 + 9.1 + ... + 9.9 = 94.5
model
=
build_naive_model
()
pruners
[
'level'
](
model
).
compress
()
x
=
model
(
tensor1x10
)
assert
x
.
numpy
()
==
94.5
try
:
from
tensorflow.keras
import
Model
,
Sequential
from
tensorflow.keras.layers
import
(
Conv2D
,
Dense
,
Flatten
,
MaxPool2D
)
from
nni.compression.tensorflow
import
LevelPruner
pruners
=
{
'level'
:
(
lambda
model
:
LevelPruner
(
model
,
[{
'sparsity'
:
0.9
,
'op_types'
:
[
'default'
]}])),
}
class
CnnModel
(
Model
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
Conv2D
(
filters
=
10
,
kernel_size
=
3
,
activation
=
'relu'
)
self
.
pool
=
MaxPool2D
(
pool_size
=
2
)
self
.
flatten
=
Flatten
()
self
.
fc1
=
Dense
(
units
=
10
,
activation
=
'relu'
)
self
.
fc2
=
Dense
(
units
=
5
,
activation
=
'softmax'
)
def
call
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
pool
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc1
(
x
)
x
=
self
.
fc2
(
x
)
return
x
class
NaiveModel
(
Model
):
def
__init__
(
self
):
super
().
__init__
()
self
.
fc
=
Dense
(
units
=
10
,
use_bias
=
False
)
def
call
(
self
,
x
):
return
tf
.
math
.
reduce_sum
(
self
.
fc
(
x
))
except
Exception
:
pass
def
build_naive_model
():
model
=
NaiveModel
()
model
.
build
(
tensor1x10
.
shape
)
weight
=
[[(
i
+
j
*
0.1
)
for
i
in
range
(
10
)]
for
j
in
range
(
10
)]
model
.
set_weights
([
np
.
array
(
weight
)])
return
model
if
__name__
==
'__main__'
:
unittest
.
main
()
src/sdk/pynni/tests/test_model_speedup.py
View file @
bf8be1e7
...
@@ -145,7 +145,7 @@ class SpeedupTestCase(TestCase):
...
@@ -145,7 +145,7 @@ class SpeedupTestCase(TestCase):
assert
model
.
backbone2
.
fc1
.
in_features
==
int
(
orig_model
.
backbone2
.
fc1
.
in_features
*
SPARSITY
)
assert
model
.
backbone2
.
fc1
.
in_features
==
int
(
orig_model
.
backbone2
.
fc1
.
in_features
*
SPARSITY
)
def
test_speedup_integration
(
self
):
def
test_speedup_integration
(
self
):
for
model_name
in
[
'resnet18'
,
'squeezenet1_1'
,
'mobilenet_v2'
,
'densenet121'
,
'inception_v3'
]:
for
model_name
in
[
'resnet18'
,
'squeezenet1_1'
,
'mobilenet_v2'
,
'densenet121'
,
'densenet169'
,
'inception_v3'
]:
Model
=
getattr
(
models
,
model_name
)
Model
=
getattr
(
models
,
model_name
)
net
=
Model
(
pretrained
=
True
,
progress
=
False
).
to
(
device
)
net
=
Model
(
pretrained
=
True
,
progress
=
False
).
to
(
device
)
speedup_model
=
Model
().
to
(
device
)
speedup_model
=
Model
().
to
(
device
)
...
...
src/webui/src/components/Modals/Compare.tsx
View file @
bf8be1e7
...
@@ -85,8 +85,10 @@ class Compare extends React.Component<CompareProps, {}> {
...
@@ -85,8 +85,10 @@ class Compare extends React.Component<CompareProps, {}> {
containLabel
:
true
containLabel
:
true
},
},
legend
:
{
legend
:
{
// more than 10 trials will hide legend
type
:
'
scroll
'
,
data
:
idsList
.
length
>
10
?
null
:
idsList
right
:
40
,
left
:
idsList
.
length
>
6
?
80
:
null
,
data
:
idsList
},
},
xAxis
:
{
xAxis
:
{
type
:
'
category
'
,
type
:
'
category
'
,
...
@@ -135,8 +137,17 @@ class Compare extends React.Component<CompareProps, {}> {
...
@@ -135,8 +137,17 @@ class Compare extends React.Component<CompareProps, {}> {
isComplexSearchSpace
=
(
typeof
parameterList
[
0
][
parameterKeys
[
0
]]
===
'
object
'
)
isComplexSearchSpace
=
(
typeof
parameterList
[
0
][
parameterKeys
[
0
]]
===
'
object
'
)
?
true
:
false
;
?
true
:
false
;
}
}
const
width
=
this
.
getWebUIWidth
();
let
scrollClass
;
if
(
width
>
1200
)
{
scrollClass
=
idList
.
length
>
3
?
'
flex
'
:
''
;
}
else
if
(
width
<
700
)
{
scrollClass
=
idList
.
length
>
1
?
'
flex
'
:
''
;
}
else
{
scrollClass
=
idList
.
length
>
2
?
'
flex
'
:
''
;
}
return
(
return
(
<
table
className
=
"
compare-modal-table
"
>
<
table
className
=
{
`
compare-modal-table
${
scrollClass
}
`
}
>
<
tbody
>
<
tbody
>
<
tr
>
<
tr
>
<
td
className
=
"column"
>
Id
</
td
>
<
td
className
=
"column"
>
Id
</
td
>
...
@@ -200,6 +211,10 @@ class Compare extends React.Component<CompareProps, {}> {
...
@@ -200,6 +211,10 @@ class Compare extends React.Component<CompareProps, {}> {
);
);
}
}
getWebUIWidth
=
():
number
=>
{
return
window
.
innerWidth
;
}
componentDidMount
():
void
{
componentDidMount
():
void
{
this
.
_isCompareMount
=
true
;
this
.
_isCompareMount
=
true
;
}
}
...
...
src/webui/src/components/public-child/PaiTrialChild.tsx
View file @
bf8be1e7
import
*
as
React
from
'
react
'
;
import
*
as
React
from
'
react
'
;
import
{
DOWNLOAD_IP
}
from
'
../../static/const
'
;
import
{
DOWNLOAD_IP
}
from
'
../../static/const
'
;
import
LogPathChild
from
'
./LogPathChild
'
;
interface
PaiTrialChildProps
{
interface
PaiTrialChildProps
{
logString
:
string
;
logString
:
string
;
...
@@ -21,7 +22,7 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
...
@@ -21,7 +22,7 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
{
{
logString
===
''
logString
===
''
?
?
<
div
/>
null
:
:
<
div
>
<
div
>
{
{
...
@@ -33,10 +34,13 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
...
@@ -33,10 +34,13 @@ class PaiTrialChild extends React.Component<PaiTrialChildProps, {}> {
href
=
{
`
${
DOWNLOAD_IP
}
/trial_
${
id
}
.log`
}
href
=
{
`
${
DOWNLOAD_IP
}
/trial_
${
id
}
.log`
}
style
=
{
{
marginRight
:
10
}
}
style
=
{
{
marginRight
:
10
}
}
>
>
t
rial stdout
T
rial stdout
</
a
>
</
a
>
:
:
<
span
>
trial stdout:
{
logString
}
</
span
>
<
LogPathChild
eachLogpath
=
{
logString
}
logName
=
"Trial stdout:"
/>
}
}
</
div
>
</
div
>
}
}
...
...
src/webui/src/components/public-child/PaiTrialLog.tsx
View file @
bf8be1e7
...
@@ -42,7 +42,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
...
@@ -42,7 +42,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
>
>
Trial stdout
Trial stdout
</
a
>
</
a
>
<
a
target
=
"_blank"
rel
=
"noopener noreferrer"
href
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
>
hdfsL
og
</
a
>
<
a
target
=
"_blank"
rel
=
"noopener noreferrer"
href
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
>
NFS l
og
</
a
>
</
div
>
</
div
>
:
:
<
div
>
<
div
>
...
@@ -52,7 +52,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
...
@@ -52,7 +52,7 @@ class PaitrialLog extends React.Component<PaitrialLogProps, {}> {
/>
/>
<
LogPathChild
<
LogPathChild
eachLogpath
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
eachLogpath
=
{
logStr
.
split
(
'
,
'
)[
1
]
}
logName
=
"Log on
HD
FS:"
logName
=
"Log on
N
FS:"
/>
/>
</
div
>
</
div
>
}
}
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment